From 0e4b076ca0338b29f7297e9d105d01f367a748cb Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 12 Jun 2024 05:43:12 +0000
Subject: [PATCH] [zero] remove redundant members in BucketStore

---
 .../low_level/bookkeeping/bucket_store.py | 24 +------------------
 .../zero/low_level/low_level_strategy.py  |  7 ++----
 2 files changed, 3 insertions(+), 28 deletions(-)

diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 1496603fabeb..d6898f74e7bd 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -1,13 +1,10 @@
-from typing import Dict, Optional
+from typing import Dict
 
 import torch
-import torch.distributed as dist
 from torch import Tensor
 from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup
 
-from colossalai.accelerator import get_accelerator
-
 from .base_store import BaseStore
 
 
@@ -16,28 +13,9 @@ def __init__(
         self,
         torch_pg: ProcessGroup,
         reduce_bucket_size: int,
-        overlap_communication: bool,
-        communication_dtype: Optional[torch.dtype] = None,
-        moe_extra_dp_process_group: ProcessGroup = None,
     ):
         super().__init__(torch_pg)
         self.reduce_bucket_size = reduce_bucket_size
-        # communication params
-        self._overlap_communication = overlap_communication
-        self._communication_dtype = communication_dtype
-        if self._overlap_communication:
-            self.comm_stream = get_accelerator().Stream()
-        self.zero_local_rank = dist.get_rank(group=self.torch_pg)
-        self.zero_world_size = dist.get_world_size(group=self.torch_pg)
-        # extra dp
-        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
-        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
-        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
-        # And moe working and master param are split by extra dp pg.
-        self.moe_extra_dp_pg = moe_extra_dp_process_group
-        if self.moe_extra_dp_pg is not None:
-            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
-            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
         self.reset_all()
 
     def reset_all(self) -> None:
diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py
index 7298ef543eae..e45f39cc726d 100644
--- a/colossalai/zero/low_level/low_level_strategy.py
+++ b/colossalai/zero/low_level/low_level_strategy.py
@@ -66,9 +66,7 @@ def __init__(
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(process_group)
         self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
-        self._bucket_store = BucketStore(
-            process_group, reduce_bucket_size=reduce_bucket_size, overlap_communication=overlap_communication
-        )
+        self._bucket_store = BucketStore(process_group, reduce_bucket_size=reduce_bucket_size)
 
         # working and master params for mixed precision training
         group_params = []
@@ -85,7 +83,6 @@ def __init__(
 
         # communication params
         self._overlap_communication = overlap_communication
-        self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
 
         # initialize communication stream for
@@ -172,7 +169,7 @@ def _add_to_bucket(self, param):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if (
-            self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
+            self._bucket_store.num_elements_in_bucket() + param_size > self._bucket_store.reduce_bucket_size
             or LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID != self._bucket_store.current_group_id
         ):
             self._run_reduction()
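Note (not part of the patch): a minimal usage sketch of the slimmed-down BucketStore after this change. The constructor now takes only the process group and the reduce bucket size, and callers such as LowLevelOptStrategy read reduce_bucket_size back from the store instead of caching a private _reduce_bucket_size copy. The single-process gloo group and the 1 MiB bucket size below are illustrative assumptions, not values taken from the patch.

    import torch.distributed as dist

    from colossalai.zero.low_level.bookkeeping.bucket_store import BucketStore

    # Illustrative setup: a single-process gloo group stands in for the real ZeRO dp group.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    # After this patch the store only tracks bucketing state; communication params
    # (overlap_communication, communication_dtype, moe process groups) stay in the strategy.
    store = BucketStore(dist.group.WORLD, reduce_bucket_size=1024 * 1024)

    # The strategy now reads the flush threshold from the store, e.g. in _add_to_bucket:
    #   num_elements_in_bucket() + param_size > self._bucket_store.reduce_bucket_size
    print(store.reduce_bucket_size)  # 1048576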