7 changes: 7 additions & 0 deletions colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -100,6 +100,9 @@ def step(self):

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)

def is_running(self):
return self.start_dynamic_batching.is_running()


class Driver:
@@ -162,3 +165,7 @@ def step(self):

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])

def is_running(self):
results = ray.get([w.is_running.remote() for w in self.workers])
return any(results)
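
For context, a minimal polling sketch for the new driver-level hook (hypothetical caller code: `driver`, `prompt_ids`, `sampling_params`, `request_id`, and `prompt` are assumed to already exist, as in the add_req call sites above):

import time

# Submit a request, then poll until the workers report they are idle.
# Driver.is_running() fans out over the Ray workers and stays truthy while
# any worker still has a running batch or queued requests (any(results) above).
driver.add_req(prompt_ids, sampling_params, request_id, prompt)
while driver.is_running():
    time.sleep(0.1)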
4 changes: 3 additions & 1 deletion colossalai/inference/manager.py
@@ -257,7 +257,9 @@ def generate(self, prompts, sampling_params, request_id):
"""
self.add_input(request_id, sampling_params, prompts)
return self.loop_for_fwd()


def is_running(self):
return self.running_batch is not None or self.req_queue.waiting_req_list

def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
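A small standalone sketch of the truthiness the new check relies on (toy objects, not the real manager): `is_running` is truthy when there is an active running batch or when requests are still waiting, and it may return the waiting list itself rather than a strict bool, which is why the Ray driver and the test below treat it as a truthy value.

# Toy stand-ins for the manager's fields, only to illustrate the check.
class _ToyQueue:
    def __init__(self):
        self.waiting_req_list = []

running_batch = None
req_queue = _ToyQueue()
print(bool(running_batch is not None or req_queue.waiting_req_list))  # False
req_queue.waiting_req_list.append("request-0")
print(bool(running_batch is not None or req_queue.waiting_req_list))  # True
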
35 changes: 11 additions & 24 deletions colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -133,17 +133,11 @@ def bloom_model_forward(
assert hasattr(self, "infer_state")
infer_state = self.infer_state

# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
# if self.cache_manager.past_key_values_length > 0:
if infer_state.cache_manager.past_key_values_length > 0:
# update the past key values length in cache manager,
# NOTE use BatchInferState.past_key_values_length instead the one in cache manager
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length_with_past + past_key_values_length

# infer_state.cache_manager = self.cache_manager
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1

if use_cache and seq_length != 1:
# prefill stage
@@ -160,21 +154,19 @@ def bloom_model_forward(
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)

@@ -195,6 +187,7 @@
past_key_values_length=past_key_values_length,
)

infer_state.decode_layer_id = 0
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -228,6 +221,7 @@ def custom_forward(*inputs):
infer_state=infer_state,
)

infer_state.decode_layer_id += 1
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
@@ -247,7 +241,6 @@ def custom_forward(*inputs):
# and update these information in engine.generate after model foward called
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.decode_layer_id = 0

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -453,9 +446,6 @@ def bloom_attention_forward(
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id

if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1

if infer_state.is_context_stage:
# context process
max_input_len = q_length
@@ -506,15 +496,12 @@
b_loc,
b_start_loc,
b_seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
alibi,
)

context_layer = output.view(batch_size, q_length, H * D_HEAD)

# update layer id
infer_state.decode_layer_id += 1

# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
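Taken together, these bloom changes drop the per-forward `cache_manager.past_key_values_length` bookkeeping in favor of positions derived from `infer_state.max_len_in_batch`, and move the `decode_layer_id` counter out of the attention layer into the model-level block loop. A minimal sketch of the new decode-stage bookkeeping (field names assumed from the hunks above; simplified, not the real module):

def decode_bookkeeping_sketch(infer_state, decode_mem_index, num_layers):
    # The cache slot for the newly generated token is recorded at column
    # max_len_in_batch - 1 of block_loc, instead of a column computed from
    # cache_manager.past_key_values_length.
    infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = decode_mem_index
    # decode_layer_id is reset before the layer loop and bumped once per
    # transformer block, rather than inside bloom_attention_forward.
    infer_state.decode_layer_id = 0
    for _ in range(num_layers):
        # ... run one transformer block with infer_state ...
        infer_state.decode_layer_id += 1
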
7 changes: 7 additions & 0 deletions tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -35,11 +35,18 @@ async def get_result(request_id, prompt, sampling_params):
if test_async:
print("test_async: ", test_async)
result = asyncio.run(get_result(request_id, prompt, sampling_params))
assert result is not None
print("result: ", result)
else:
print("test_async: ", test_async)
result = driver.generate(request_id, prompt, sampling_params)
assert result is not None
print("result: ", result)

is_running = None
is_running = driver.is_running()
assert is_running is not None
print("is_running: ", is_running)

def check_ray_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")