From d1e2538aad70eab32c95d3193f680e6bdaaa6281 Mon Sep 17 00:00:00 2001 From: run-qiao <809732792@qq.com> Date: Wed, 13 Jul 2022 18:54:42 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style --- colossalai/nn/layer/wrapper/pipeline_wrapper.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py index 20813dc8933d..ef1d794cc68f 100644 --- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -6,6 +6,7 @@ class PipelineSharedModuleWrapper: + def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' self.pipeline_ranks = pipeline_ranks @@ -22,10 +23,7 @@ def _init_group(self): num_pp_stages = num_dp_groups // pp_size for i in range(dp_size): for j in range(num_pp_stages): - pipeline_ranks = list( - range(i * num_dp_groups + j, - (i + 1) * num_dp_groups, - num_pp_stages)) + pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages)) sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks] group = dist.new_group(sub_ranks) if rank in sub_ranks: