colossalai/inference/engine/engine.py (6 changes: 4 additions & 2 deletions)
@@ -126,7 +126,7 @@ def __init__(
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)

- stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+ stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True if pp_size * tp_size > 1 else False)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
@@ -142,7 +142,9 @@ def __init__(
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

- self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
+ self.model = self._shardformer(
+ model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
+ )
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)

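Taken together, the two hunks above guard the single-device case (pp_size * tp_size == 1): the stage manager is told whether pipeline parallelism is actually enabled, and _shardformer receives None instead of a tensor-parallel process group so no TP sharding is attempted. A minimal sketch of the same guard, assuming the variables visible in this __init__ (pp_size, tp_size, pg_mesh, TP_AXIS); the helper name is hypothetical and not part of this PR:

# Hypothetical helper mirroring the guard added above (sketch only).
def _tp_group_or_none(pg_mesh, tp_axis, pp_size, tp_size):
    # With a single process there is no tensor-parallel group to shard against,
    # so downstream code such as _shardformer should receive None instead.
    if pp_size * tp_size > 1:
        return pg_mesh.get_group_along_axis(tp_axis)
    return None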
tests/test_infer/test_hybrid_bloom.py (24 changes: 20 additions & 4 deletions)
@@ -78,17 +78,32 @@ def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
torch.cuda.empty_cache()


- def check_tp_pipeline_inference(rank, world_size, port):
+ @parameterize("tp_size", [1])
+ @parameterize("pp_size", [1])
+ @parameterize("max_output_len", [2])
+ @parameterize("micro_batch_size", [1])
+ @clear_cache_before_run()
+ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()


+ def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()


- def check_single_inference(rank, world_size, port):
+ def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()


+ def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_single_inference_test()


@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
@@ -97,8 +112,9 @@ def check_single_inference(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
- spawn(check_tp_pipeline_inference, nprocs=4)
- spawn(check_single_inference, nprocs=2)
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)


if __name__ == "__main__":
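The same restructuring is repeated for the chatglm2 and llama tests below: a 4-process run exercising TP and PP together, a 2-process run exercising TP-only and PP-only, and a new 1-process run exercising the single-device path. Assuming colossalai.testing.parameterize calls the decorated function once per listed value (as its use above suggests), the new run_single_inference_test reduces to roughly the following sketch:

# Rough expansion of the decorated run_single_inference_test above (sketch only;
# assumes @parameterize injects each listed value as a keyword argument).
def run_single_inference_test_expanded():
    pipeline_inference_test(tp_size=1, pp_size=1, max_output_len=2, micro_batch_size=1)
    torch.cuda.empty_cache()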
tests/test_infer/test_hybrid_chatglm2.py (24 changes: 20 additions & 4 deletions)
@@ -86,17 +86,32 @@ def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
torch.cuda.empty_cache()


- def check_tp_pipeline_inference(rank, world_size, port):
+ @parameterize("tp_size", [1])
+ @parameterize("pp_size", [1])
+ @parameterize("max_output_len", [2])
+ @parameterize("micro_batch_size", [1])
+ @clear_cache_before_run()
+ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()


+ def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()


- def check_single_inference(rank, world_size, port):
+ def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()


+ def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_single_inference_test()


@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
@@ -105,8 +120,9 @@ def check_single_inference(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
- spawn(check_tp_pipeline_inference, nprocs=4)
- spawn(check_single_inference, nprocs=2)
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)


if __name__ == "__main__":
tests/test_infer/test_hybrid_llama.py (24 changes: 20 additions & 4 deletions)
@@ -83,17 +83,32 @@ def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
torch.cuda.empty_cache()


- def check_tp_pipeline_inference(rank, world_size, port):
+ @parameterize("tp_size", [1])
+ @parameterize("pp_size", [1])
+ @parameterize("max_output_len", [2])
+ @parameterize("micro_batch_size", [1])
+ @clear_cache_before_run()
+ def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()


+ def check_tp_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_pipeline_inference_test()


- def check_single_inference(rank, world_size, port):
+ def check_tp_or_pp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()
run_pipeline_inference_test()


+ def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_single_inference_test()


@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
@@ -102,8 +117,9 @@ def check_single_inference(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
- spawn(check_tp_pipeline_inference, nprocs=4)
- spawn(check_single_inference, nprocs=2)
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)


if __name__ == "__main__":