24 changes: 1 addition & 23 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -1,13 +1,10 @@
-from typing import Dict, Optional
+from typing import Dict
 
-import torch
 import torch.distributed as dist
 from torch import Tensor
 from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup
 
-from colossalai.accelerator import get_accelerator
-
 from .base_store import BaseStore
 
 
@@ -16,28 +13,9 @@ def __init__(
         self,
         torch_pg: ProcessGroup,
         reduce_bucket_size: int,
-        overlap_communication: bool,
-        communication_dtype: Optional[torch.dtype] = None,
-        moe_extra_dp_process_group: ProcessGroup = None,
     ):
         super().__init__(torch_pg)
         self.reduce_bucket_size = reduce_bucket_size
-        # communication params
-        self._overlap_communication = overlap_communication
-        self._communication_dtype = communication_dtype
-        if self._overlap_communication:
-            self.comm_stream = get_accelerator().Stream()
-        self.zero_local_rank = dist.get_rank(group=self.torch_pg)
-        self.zero_world_size = dist.get_world_size(group=self.torch_pg)
-        # extra dp
-        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
-        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
-        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
-        # And moe working and master param are split by extra dp pg.
-        self.moe_extra_dp_pg = moe_extra_dp_process_group
-        if self.moe_extra_dp_pg is not None:
-            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
-            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
         self.reset_all()
 
     def reset_all(self) -> None:
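Net effect of the diff above: BucketStore keeps only the reduce-bucket threshold and its bucket state; the communication settings (overlap flag, dtype, stream) and the MoE extra-DP bookkeeping no longer live in the store. A minimal sketch of the resulting class, with BaseStore stubbed for self-containedness (the real one lives in bookkeeping/base_store.py) and the bucket state reduced to a single placeholder counter:

from torch.distributed import ProcessGroup


class BaseStore:
    # stub for illustration; the real class lives in bookkeeping/base_store.py
    def __init__(self, torch_pg: ProcessGroup):
        self.torch_pg = torch_pg


class BucketStore(BaseStore):
    def __init__(self, torch_pg: ProcessGroup, reduce_bucket_size: int):
        super().__init__(torch_pg)
        # the only configuration left after this change: the flush threshold
        self.reduce_bucket_size = reduce_bucket_size
        self.reset_all()

    def reset_all(self) -> None:
        # placeholder state; the real method also clears the grad/param lists
        self._num_elements = 0

    def num_elements_in_bucket(self) -> int:
        # existing API referenced by _add_to_bucket in low_level_strategy.py
        return self._num_elements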
7 changes: 2 additions & 5 deletions colossalai/zero/low_level/low_level_strategy.py
@@ -66,9 +66,7 @@ def __init__(
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(process_group)
         self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
-        self._bucket_store = BucketStore(
-            process_group, reduce_bucket_size=reduce_bucket_size, overlap_communication=overlap_communication
-        )
+        self._bucket_store = BucketStore(process_group, reduce_bucket_size=reduce_bucket_size)
 
         # working and master params for mixed precision training
         group_params = []
@@ -85,7 +83,6 @@ def __init__(
 
         # communication params
         self._overlap_communication = overlap_communication
-        self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
 
         # initialize communication stream for
@@ -172,7 +169,7 @@ def _add_to_bucket(self, param):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if (
-            self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size
+            self._bucket_store.num_elements_in_bucket() + param_size > self._bucket_store.reduce_bucket_size
             or LowLevelOptStrategy.DEFAULT_STORE_GROUP_ID != self._bucket_store.current_group_id
         ):
             self._run_reduction()
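Net effect in low_level_strategy.py: the strategy stops caching its own _reduce_bucket_size copy and reads the threshold from the bucket store, so the value lives in exactly one place. A hedged sketch of the flush condition as a standalone helper (should_flush_bucket is a hypothetical name; num_elements_in_bucket, reduce_bucket_size, and current_group_id are the store attributes used in the hunk above):

def should_flush_bucket(bucket_store, param_size: int, default_group_id: int) -> bool:
    # flush when the incoming grad would overflow the bucket, or when the
    # bucket already holds grads belonging to a different param group
    return (
        bucket_store.num_elements_in_bucket() + param_size > bucket_store.reduce_bucket_size
        or default_group_id != bucket_store.current_group_id
    )

With this split, resizing the bucket only requires changing BucketStore's constructor argument; the strategy no longer has a duplicate field to keep in sync.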