diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py
index a40a00e2666c..70cc21436456 100644
--- a/colossalai/inference/dynamic_batching/ray_dist_init.py
+++ b/colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -100,6 +100,9 @@ def step(self):
 
     def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
         self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)
+
+    def is_running(self):
+        return self.start_dynamic_batching.is_running()
 
 
 class Driver:
@@ -162,3 +165,7 @@ def step(self):
 
     def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
         ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
+
+    def is_running(self):
+        results = ray.get([w.is_running.remote() for w in self.workers])
+        return any(results)
diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py
index 30717a915e3b..bd33837dc451 100644
--- a/colossalai/inference/manager.py
+++ b/colossalai/inference/manager.py
@@ -257,7 +257,9 @@ def generate(self, prompts, sampling_params, request_id):
         """
         self.add_input(request_id, sampling_params, prompts)
         return self.loop_for_fwd()
-
+
+    def is_running(self):
+        return self.running_batch is not None or self.req_queue.waiting_req_list
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
     try:
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index 27a26caabefa..c10f7e620852 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -133,17 +133,11 @@ def bloom_model_forward(
 
         assert hasattr(self, "infer_state")
         infer_state = self.infer_state
 
-        # Compute alibi tensor: check build_alibi_tensor documentation
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-        # if self.cache_manager.past_key_values_length > 0:
-        if infer_state.cache_manager.past_key_values_length > 0:
-            # update the past key values length in cache manager,
-            # NOTE use BatchInferState.past_key_values_length instead the one in cache manager
-            past_key_values_length = infer_state.cache_manager.past_key_values_length
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-        # infer_state.cache_manager = self.cache_manager
+        if infer_state.is_context_stage:
+            past_key_values_length = 0
+        else:
+            past_key_values_length = infer_state.max_len_in_batch - 1
 
         if use_cache and seq_length != 1:
             # prefill stage
@@ -160,21 +154,19 @@ def bloom_model_forward(
                 infer_state.decode_mem_index = alloc_mem[0]
                 infer_state.decode_mem_start = alloc_mem[1]
                 infer_state.decode_mem_end = alloc_mem[2]
-                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+                infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
             else:
                 print(f" *** Encountered allocation non-contiguous")
-                print(
-                    f"    infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
-                )
+                print(f"    infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
                 infer_state.decode_is_contiguous = False
                 alloc_mem = infer_state.cache_manager.alloc(batch_size)
                 infer_state.decode_mem_index = alloc_mem
                 # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
                 # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+                infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
 
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+            attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
@@ -195,6 +187,7 @@ def bloom_model_forward(
             past_key_values_length=past_key_values_length,
         )
 
+        infer_state.decode_layer_id = 0
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -228,6 +221,7 @@ def custom_forward(*inputs):
                     infer_state=infer_state,
                 )
 
+            infer_state.decode_layer_id += 1
             hidden_states = outputs[0]
             if use_cache is True:
                 presents = presents + (outputs[1],)
@@ -247,7 +241,6 @@ def custom_forward(*inputs):
         # and update these information in engine.generate after model foward called
         infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         infer_state.seq_len += 1
-        infer_state.decode_layer_id = 0
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -453,9 +446,6 @@ def bloom_attention_forward(
         mem_manager = infer_state.cache_manager
         layer_id = infer_state.decode_layer_id
 
-        if layer_id == 0:  # once per model.forward
-            infer_state.cache_manager.past_key_values_length += q_length  # += 1
-
         if infer_state.is_context_stage:
             # context process
             max_input_len = q_length
@@ -506,15 +496,12 @@ def bloom_attention_forward(
                 b_loc,
                 b_start_loc,
                 b_seq_len,
-                infer_state.cache_manager.past_key_values_length,
+                infer_state.max_len_in_batch,
                 alibi,
             )
             context_layer = output.view(batch_size, q_length, H * D_HEAD)
 
-        # update layer id
-        infer_state.decode_layer_id += 1
-
         # NOTE: always set present as none for now, instead of returning past key value to the next decoding,
         # we create the past key value pair from the cache manager
         present = None
 
diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
index 4cf9881f41dc..0eea9ef16345 100644
--- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py
+++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -35,11 +35,18 @@ async def get_result(request_id, prompt, sampling_params):
     if test_async:
         print("test_async: ", test_async)
         result = asyncio.run(get_result(request_id, prompt, sampling_params))
+        assert result is not None
         print("result: ", result)
     else:
        print("test_async: ", test_async)
        result = driver.generate(request_id, prompt, sampling_params)
+        assert result is not None
        print("result: ", result)
+
+    is_running = None
+    is_running = driver.is_running()
+    assert is_running is not None
+    print("is_running: ", is_running)
 
 def check_ray_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
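
For context, a minimal usage sketch of the hook this patch adds: Driver.add_req() queues a request on every Ray worker, and the new Driver.is_running() reports whether any worker still has a running batch or a non-empty waiting_req_list, so a caller can keep stepping the engine until the queue drains. This is not part of the patch; the helper name and its arguments below are placeholders, and only add_req, is_running, and step come from the diff above.

# Hypothetical helper (assumption), built only on methods shown in the diff:
# add_req() submits the request, is_running() polls worker state, step()
# advances dynamic batching until nothing is left to process.
def drain_requests(driver, prompt_ids, sampling_params, request_id, prompt):
    driver.add_req(prompt_ids, sampling_params, request_id, prompt)
    while driver.is_running():
        driver.step()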