From e70830119691ad51f8a0795abee8634dc8653151 Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 20 Jun 2024 02:53:24 +0000 Subject: [PATCH] [zero] comments and naming --- .../plugin/moe_hybrid_parallel_plugin.py | 2 +- .../low_level/bookkeeping/bucket_store.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 74 ++++++++++--------- tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 2 +- 4 files changed, 43 insertions(+), 39 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 8a2415fab5cb..4b047ae1f10c 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -80,7 +80,7 @@ def __init__( super().__init__( optimizer=optimizer, - pg_param_list=pg_param_list, + pg_to_param_list=pg_param_list, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 5b1776062c48..19d20de2b250 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -15,13 +15,11 @@ def __init__( self, torch_pg: ProcessGroup, reduce_bucket_size: int, - overlap_comm: bool = False, ): super().__init__(torch_pg) self.reduce_bucket_size = reduce_bucket_size self.reset_all() - if overlap_comm: - self.comm_stream = get_accelerator().Stream() + self.comm_stream = get_accelerator().Stream() def reset_all(self) -> None: # init diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index bcfdb44478d3..12ff466dad27 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -29,7 +29,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): def __init__( self, num_working_param_groups: int, - grad_stores: 
Dict[nn.Parameter, GradientStore], + pg_to_grad_store: Dict[ProcessGroup, GradientStore], initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, @@ -48,10 +48,10 @@ def __init__( max_scale, ) self.num_working_param_groups = num_working_param_groups - self.grad_stores = grad_stores + self.pg_to_grad_store = pg_to_grad_store def check_local_overflow(self) -> bool: - for store in self.grad_stores.values(): + for store in self.pg_to_grad_store.values(): for group_id in range(self.num_working_param_groups): for avg_grad in store.get_working_grads_by_group_id(group_id): if avg_grad is not None and has_inf_or_nan(avg_grad): @@ -65,7 +65,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def __init__( self, optimizer: Optimizer, - pg_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None, + pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2.0, @@ -89,14 +89,14 @@ def __init__( self._logger = get_dist_logger() self._verbose = verbose - if pg_param_list is None: - pg_param_list = {dist.group.WORLD: []} + if pg_to_param_list is None: + pg_to_param_list = {dist.group.WORLD: []} for group in self.optim.param_groups: - pg_param_list[dist.group.WORLD].extend(group["params"]) + pg_to_param_list[dist.group.WORLD].extend(group["params"]) - self.pg_param_list = pg_param_list + self.pg_to_param_list = pg_to_param_list param_to_pg = {} - for grp, param_list in pg_param_list.items(): + for grp, param_list in pg_to_param_list.items(): for p in param_list: assert isinstance(p, nn.Parameter) param_to_pg[p] = grp @@ -148,15 +148,18 @@ def __init__( self.working_to_master_param = dict() # NOTE need to gurantee the order of process group is the same accross all ranks - self.grad_stores = {pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_param_list} - # param id to grad store, have to use id(param) as key since it is used in stores 
- self.pid2grad_store = {id(param): self.grad_stores[param_to_pg[param]] for param in param_to_pg} - self.bucket_stores = { - pg: BucketStore(pg, reduce_bucket_size, overlap_comm=self._overlap_communication) - for pg in self.pg_param_list + # process_group <---> xxx_store (one GradientStore and one BucketStore per group) + # process_group <---> [param1 param2 ...] + # each process group has its own stores + # a param belonging to a process_group uses that group's stores + self.pg_to_grad_store = { + pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list } + # param id to grad store, have to use id(param) as key since it is used in stores + self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg} + self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list} # param id to bucket store, have to use id(param) as key since it is used in stores - self.pid2bucket_store = {id(param): self.bucket_stores[param_to_pg[param]] for param in param_to_pg} + self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg} # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -190,7 +193,7 @@ def __init__( if self._dtype is torch.float16: self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( self.num_param_groups, - self.grad_stores, + self.pg_to_grad_store, initial_scale=initial_scale, min_scale=min_scale, growth_factor=growth_factor, @@ -231,9 +234,9 @@ def _create_master_param_current_rank(self, param_list): for param in param_list: padding_size = ( - self.pid2bucket_store[id(param)].world_size - - param.numel() % self.pid2bucket_store[id(param)].world_size - ) % self.pid2bucket_store[id(param)].world_size + self.pid_to_bucket_store[id(param)].world_size + - param.numel() % self.pid_to_bucket_store[id(param)].world_size + ) % self.pid_to_bucket_store[id(param)].world_size 
self.record_param_padding_size(param, padding_size) with torch.no_grad(): @@ -246,9 +249,9 @@ def _create_master_param_current_rank(self, param_list): padding_param = param.data.view(-1) splited_params = padding_param.split( - padding_param.numel() // self.pid2bucket_store[id(param)].world_size + padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size ) - splited_params = splited_params[self.pid2bucket_store[id(param)].local_rank] + splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank] # use fp32 when master_weights is True if self._master_weights is True: @@ -288,7 +291,7 @@ def _grad_handler(param, group_id): ####################### def _run_reduction(self): - for bucket_store in self.bucket_stores.values(): + for bucket_store in self.pg_to_bucket_store.values(): if bucket_store.num_elements_in_bucket() <= 0: continue @@ -367,10 +370,13 @@ def _add_grad( param_id: int, rank: int = 0, ) -> None: - if len(self.pid2grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: - self.pid2grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id) + if ( + len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) + < partition_num + ): + self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id) else: - self.pid2grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id) + self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id) def _add_to_bucket(self, param, group_id): param_size = param.numel() @@ -380,13 +386,13 @@ def _add_to_bucket(self, param, group_id): # or got a grad of param from another group # after reduction, the bucket will be empty if ( - self.pid2bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size - or group_id != self.pid2bucket_store[id(param)].current_group_id + 
self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self.pid_to_bucket_store[id(param)].current_group_id ): self._run_reduction() padding_size = self.get_param_padding_size(param) - self.pid2bucket_store[id(param)].add_param_grad(group_id, param, padding_size) + self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods @@ -429,11 +435,11 @@ def backward_by_grad(self, tensor, grad): get_accelerator().synchronize() def zero_bucket_stores(self): - for bucket_store in self.bucket_stores.values(): + for bucket_store in self.pg_to_bucket_store.values(): bucket_store.reset_all() def zero_grad_stores(self): - for grad_store in self.grad_stores.values(): + for grad_store in self.pg_to_grad_store.values(): grad_store.reset_all_gradients() def zero_grad(self, set_to_none=True): @@ -492,7 +498,7 @@ def step(self, closure=None): # if a working param requires grad and has no grad # it is not 'really' working, e.g. 
the droped layer # else the splited grad should be attached to the splited param - grad_store = self.pid2grad_store[id(working_param)] + grad_store = self.pid_to_grad_store[id(working_param)] grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) grad_index = 0 if self._partition_grads else grad_store.local_rank if len(grads) > 0: @@ -507,7 +513,7 @@ def step(self, closure=None): # compute norm norm_group = 0 - for grad_store in self.grad_stores.values(): + for grad_store in self.pg_to_grad_store.values(): working_grads = grad_store.get_working_grads_by_group_id(group_id) norm_group += self._compute_grad_norm(pg=grad_store.torch_pg, gradients=working_grads) @@ -840,7 +846,7 @@ def get_padding_map(self) -> Dict[int, Tensor]: return self._padding_map def get_param_grad(self, working_param: nn.Parameter) -> Tensor: - grad_store = self.pid2grad_store[id(working_param)] + grad_store = self.pid_to_grad_store[id(working_param)] partial_grad = grad_store.get_working_grad_by_param_id(id(working_param)) if partial_grad is None: return None diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py index c0340eb96f70..042b3d8aedc5 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py @@ -68,7 +68,7 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch. zero_optimizer = LowLevelZeroOptimizer( zero_optimizer, - pg_param_list=pg_param_list, + pg_to_param_list=pg_param_list, master_weights=master_weights, initial_scale=1, overlap_communication=False,