diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py
index a58dde01d250..4515aeab3469 100644
--- a/colossalai/inference/async_engine.py
+++ b/colossalai/inference/async_engine.py
@@ -79,11 +79,12 @@ def _step(self):
         """
         request_outputs = self.driver.step()
         if request_outputs is not None:
+            print("request_outputs: ", request_outputs)
             for request_output in request_outputs:
                 self._request_tracker.process_request_output(request_output)
         self._request_tracker.add_stop()
 
-    def abort(self, request_id: str):
+    def abort_request(self, request_id: str):
         self.driver.abort(request_id)
 
     def _has_requests_in_progress(self):
diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py
index 42ff8bf1e9ef..18226d78c20c 100644
--- a/colossalai/inference/manager.py
+++ b/colossalai/inference/manager.py
@@ -42,7 +42,12 @@ def __init__(
         running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
         self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
         # all the inputs should be put into req_queue: waiting req list
-
+        assert max_total_token_num >= self.engine.max_batch_size * (
+            self.engine.max_input_len + self.engine.max_output_len
+        ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)"
+        assert (
+            batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
+        ), "batch_max_tokens should be greater than (max_input_len+max_output_len)"
         self.running_batch: Batch = running_batch
         self.eos_id = eos_id
         self.has_wait_tokens = 0
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index f7fb7a825694..a98b96565c50 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -61,7 +61,6 @@ def __init__(
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
-
         # Constraints relatable with specs of devices and model
         # This may change into an optional arg in the future
         assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
@@ -380,7 +379,6 @@ def forward(self, batch_id, is_prefill):
         Forward is used in Dynamic Batching Manager
         """
         batch = self.cache.pop(batch_id)
-
         if is_prefill:
             input_ = torch.tensor(batch.all_input_ids).cuda()
         else:
diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml
index c31ae8c5fadb..59ec39779335 100644
--- a/tests/test_infer/test_dynamic_batching/config.yaml
+++ b/tests/test_infer/test_dynamic_batching/config.yaml
@@ -5,10 +5,10 @@ engine_config:
   max_input_len: 128
   max_output_len: 32
 # config for app router deployment
-# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig?
+# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
 router_config:
-  max_total_token_num: 42
-  batch_max_tokens: 42
+  max_total_token_num: 640
+  batch_max_tokens: 640
   eos_id: 0
   disable_log_stats: False
   log_stats_interval: 10
diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py
index 148d325a1d9a..6287699d7b3c 100644
--- a/tests/test_infer/test_dynamic_batching/test_async_engine.py
+++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py
@@ -23,7 +23,7 @@ def run_async_engine(path: str):
     if model is None or not os.path.exists(model):
         return
 
-    prompt = "Introduce some landmarks in Beijing"
+    prompt = "Introduce some landmarks in London.\nThe Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
     sampling_params = SamplingParams()
     asyncio.run(asy_for_loop_test(config, prompt, sampling_params))
 
@@ -32,6 +32,7 @@ async def get_result(engine, prompt, sampling_params):
     request_id = str(uuid.uuid4().hex)
     results = engine.generate(request_id, prompt, sampling_params)
     async for result in results:
+        # print(result)
         assert result is not None
 
 
diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
index 5c84b39d8f8e..a840407d5867 100644
--- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py
+++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -14,7 +14,6 @@ def run_ray_dist(path: str):
-    print(f"Using yaml file {path}")
     if not os.path.exists(path):
         return
     config = RayInitConfig.from_yaml_path(path)