Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/inference/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ def _step(self):
"""
request_outputs = self.driver.step()
if request_outputs is not None:
print("request_outputs: ", request_outputs)
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()

def abort(self, request_id: str):
def abort_request(self, request_id: str):
self.driver.abort(request_id)

def _has_requests_in_progress(self):
Expand Down
7 changes: 6 additions & 1 deletion colossalai/inference/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def __init__(
running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
# all the inputs should be put into req_queue: waiting req list

assert max_total_token_num >= self.engine.max_batch_size * (
self.engine.max_input_len + self.engine.max_output_len
), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)"
assert (
batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
), "batch_max_tokens should be greater than (max_input_len+max_output_len)"
self.running_batch: Batch = running_batch
self.eos_id = eos_id
self.has_wait_tokens = 0
Expand Down
2 changes: 0 additions & 2 deletions colossalai/inference/tensor_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)

# Constraints relatable with specs of devices and model
# This may change into an optional arg in the future
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
Expand Down Expand Up @@ -380,7 +379,6 @@ def forward(self, batch_id, is_prefill):
Forward is used in Dynamic Batching Manager
"""
batch = self.cache.pop(batch_id)

if is_prefill:
input_ = torch.tensor(batch.all_input_ids).cuda()
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_infer/test_dynamic_batching/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ engine_config:
max_input_len: 128
max_output_len: 32
# config for app router deployment
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig?
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
router_config:
max_total_token_num: 42
batch_max_tokens: 42
max_total_token_num: 640
batch_max_tokens: 640
eos_id: 0
disable_log_stats: False
log_stats_interval: 10
Expand Down
3 changes: 2 additions & 1 deletion tests/test_infer/test_dynamic_batching/test_async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run_async_engine(path: str):
if model is None or not os.path.exists(model):
return

prompt = "Introduce some landmarks in Beijing"
prompt = "Introduce some landmarks in London.\nThe Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
sampling_params = SamplingParams()
asyncio.run(asy_for_loop_test(config, prompt, sampling_params))

Expand All @@ -32,6 +32,7 @@ async def get_result(engine, prompt, sampling_params):
request_id = str(uuid.uuid4().hex)
results = engine.generate(request_id, prompt, sampling_params)
async for result in results:
# print(result)
assert result is not None


Expand Down
1 change: 0 additions & 1 deletion tests/test_infer/test_dynamic_batching/test_ray_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


def run_ray_dist(path: str):
print(f"Using yaml file {path}")
if not os.path.exists(path):
return
config = RayInitConfig.from_yaml_path(path)
Expand Down