diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index fa3c3646a592..0909a643a0c7 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -655,7 +655,6 @@ def __init__(
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
         if use_pipeline:
@@ -718,7 +717,7 @@ def _get_all_working_grads() -> List[Tensor]:
             """Retrieve all working gradients from different parameter groups."""
             all_working_grads = []
             for group_id in range(self.num_param_groups):
-                working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+                working_grads = self.get_working_grads_by_group_id(group_id)
                 all_working_grads.extend(working_grads)
             return all_working_grads

@@ -726,7 +725,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
             """Identify gradients to be synchronized in the sequence parallelism."""
             grads_to_sync = []
             for grad in all_working_grads:
-                param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                param_id_for_grad = self.get_param_id_for_grad(grad)
                 param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
                 if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
                     grads_to_sync.append(grad)
@@ -739,7 +738,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
        # Get all working gradients and gradients to be synchronized.
        all_working_grads = _get_all_working_grads()
        grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self._grad_store.require_grad_sync and grads_to_sync is not None:
+        if self.require_grad_sync and grads_to_sync is not None:
            # Synchronize sequence parallelism gradients if required.
            SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
        else:
@@ -763,7 +762,7 @@ def backward(self, loss, retain_graph=False):
        # Call the superclass backward method to compute gradients.
        super().backward(loss, retain_graph)

-        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
            self._sync_sp_grads()
        else:
@@ -788,14 +787,14 @@ def backward_by_grad(self, tensor, grad):
        # Call the superclass backward_by_grad method to compute gradients.
        super().backward_by_grad(tensor, grad)

-        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
            # If gradient synchronization is required, sync sequence parallelism gradients.
            self._sync_sp_grads()
        else:
            # If gradient synchronization is is not required, return.
            return

-    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+    def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.
@@ -811,7 +810,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
        if len(gradients) == 0:
            return 0.0

-        dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+        dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
        norm_type = float(norm_type)
@@ -842,7 +841,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
                if tp_size > 1:
-                    param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                    param_id_for_grad = self.get_param_id_for_grad(grad)
                    param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value

                    if not is_distributed_tensor(param_for_grad):
@@ -856,7 +855,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
                    for shared_param in self.shared_params:
                        if self.stage_manager.stage in shared_param:
                            stage_shared_param = shared_param[self.stage_manager.stage]
-                            working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+                            working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
                            if grad is working_grad:
                                grad_norm_exponentiated /= len(shared_param)

@@ -867,7 +866,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
        )
        if dp_size > 1:
            # compute norm in dp process group
-            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
        if tp_size > 1:
            # compute norm in tp process group
            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -1305,7 +1304,7 @@ def execute_pipeline(

        # run with gradients accumulation
        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
        ):
            return outputs

diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
index b18edff5214b..a9bd2cc1b4e9 100644
--- a/colossalai/moe/load_balance.py
+++ b/colossalai/moe/load_balance.py
@@ -292,7 +292,7 @@ def _swap_expert_param_and_optim(
            exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
            exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
        else:
-            master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
+            master_weight_ptr = optim.working_to_master_param[id(weight)]
            working_weight_ptr = weight
            exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
            exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
@@ -344,7 +344,7 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None
        # gate optim should be obtained first
        gate_shape = self.gate.shape
        # get master weight and optim
-        master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
+        master_gate_weight = optim.working_to_master_param[id(self.gate)]
        gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
        gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
        # gather
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index e8c469146eba..e24a67f9de3c 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional

 from torch import Tensor

@@ -113,7 +113,7 @@ def reset_grads_by_group_id(self, group_id: int):
    def reset_all_gradients(self):
        self._grads_of_params = dict()

-    def get_param_id_for_grad(self, grad: Tensor) -> int:
+    def get_param_id_for_grad(self, grad: Tensor) -> Optional[int]:
        """Return the id of a parameter which the gradient slice belongs to

        Args:
@@ -123,4 +123,4 @@ def get_param_id_for_grad(self, grad: Tensor) -> int:
            int: the id of a parameter which the gradient slice belongs to
        """

-        return self.grad_to_param_mapping[id(grad)]
+        return self.grad_to_param_mapping.get(id(grad), None)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 12ff466dad27..1e1673117c8d 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -80,6 +80,7 @@ def __init__(
        overlap_communication: bool = False,
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
+        dp_process_group: Optional[ProcessGroup] = None,
        forced_dtype: Optional[torch.dtype] = None,
        master_weights: bool = True,  # master weights
    ):
@@ -89,16 +90,20 @@ def __init__(
        self._logger = get_dist_logger()
        self._verbose = verbose

+        if dp_process_group is not None and pg_to_param_list is not None:
+            raise ValueError("dp_process_group and pg_to_param_list should not be provided at the same time.")
+
        if pg_to_param_list is None:
-            pg_to_param_list = {dist.group.WORLD: []}
+            unique_dp_group = dist.group.WORLD if dp_process_group is None else dp_process_group
+            pg_to_param_list = {unique_dp_group: []}
            for group in self.optim.param_groups:
-                pg_to_param_list[dist.group.WORLD].extend(group["params"])
+                pg_to_param_list[unique_dp_group].extend(group["params"])

        self.pg_to_param_list = pg_to_param_list
        param_to_pg = {}
        for grp, param_list in pg_to_param_list.items():
            for p in param_list:
-                assert isinstance(p, nn.Parameter)
+                assert isinstance(p, nn.Parameter), f"got {type(p)}"
                param_to_pg[p] = grp
        self.param_to_pg = param_to_pg

@@ -515,7 +520,7 @@ def step(self, closure=None):
            norm_group = 0
            for grad_store in self.pg_to_grad_store.values():
                working_grads = grad_store.get_working_grads_by_group_id(group_id)
-                norm_group += self._compute_grad_norm(pg=grad_store.torch_pg, gradients=working_grads)
+                norm_group += self._compute_grad_norm(dp_pg=grad_store.torch_pg, gradients=working_grads)

            norm_groups.append(norm_group)

@@ -552,7 +557,7 @@ def step(self, closure=None):
            working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
        self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

-    def _compute_grad_norm(self, pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
+    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.
@@ -575,7 +580,7 @@ def _compute_grad_norm(self, pg: ProcessGroup, gradients: List[Tensor], norm_typ
                device=get_accelerator().get_current_device(),
                dtype=torch.float,
            )
-            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=pg)
+            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_pg)
            total_norm = total_norm_cuda.item()

        else:
@@ -593,7 +598,7 @@ def _compute_grad_norm(self, pg: ProcessGroup, gradients: List[Tensor], norm_typ
            torch.distributed.all_reduce(
                total_norm_exponentiated_cuda,
                op=torch.distributed.ReduceOp.SUM,
-                group=pg,
+                group=dp_pg,
            )
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

@@ -854,3 +859,27 @@ def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
        dist.all_gather(tensor_list, partial_grad, group=grad_store.torch_pg)
        grad_flat = torch.cat(tensor_list, dim=0)
        return grad_flat[: working_param.numel()].reshape_as(working_param)
+
+    def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
+        working_grads = []
+        for grad_store in self.pg_to_grad_store.values():
+            working_grads.extend(grad_store.get_working_grads_by_group_id(group_id))
+        return working_grads
+
+    def get_param_id_for_grad(self, grad: Tensor) -> int:
+        param_id = None
+        for grad_store in self.pg_to_grad_store.values():
+            id_maybe_none = grad_store.get_param_id_for_grad(grad)
+            if id_maybe_none is not None:
+                if param_id is not None:
+                    raise ValueError("The grad mapping is not unique")
+                param_id = id_maybe_none
+        return param_id
+
+    def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
+        grad_store = self.pid_to_grad_store[param_id]
+        return grad_store.get_working_grad_by_param_id(param_id)
+
+    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+        grad_store = self.pid_to_grad_store[param_id]
+        return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index 24dc4a5d2677..ab48944d4eaa 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -59,10 +59,10 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
    # check master weight
    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
    working_param_id_set = set(id(p) for p in new_model.parameters())
-    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+    for p_id, master_param in new_optimizer.working_to_master_param.items():
        assert p_id in working_param_id_set
-        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
-        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        working_param = new_optimizer.master_to_working_param[id(master_param)]
+        padding = new_optimizer.get_param_padding_size(working_param)
        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
        assert torch.equal(
@@ -115,10 +115,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
    # check master weight
    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
    working_param_id_set = set(id(p) for p in new_model.parameters())
-    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+    for p_id, master_param in new_optimizer.working_to_master_param.items():
        assert p_id in working_param_id_set
-        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
-        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        working_param = new_optimizer.master_to_working_param[id(master_param)]
+        padding = new_optimizer.get_param_padding_size(working_param)
        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
        assert torch.equal(
diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py
index 313624e83c22..4046e41189ec 100644
--- a/tests/test_optimizer/_utils.py
+++ b/tests/test_optimizer/_utils.py
@@ -234,7 +234,7 @@ def check_dist_grad(sharded_optimizer, org_model, sharded_model, weight_layer_fo
        if org_name in weight_layer_for_check:
            org_grad = org_param.grad
            group_id = dist.get_rank(sharded_optimizer.optim.dp_group)
-            dist_grad = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))
+            dist_grad = sharded_optimizer.get_partitioned_gradients_by_param_id(group_id, id(sharded_param))

            # dist_grad concat then reshape to org_grad shape
            if dist_grad:
diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py
index 06c254e5650a..2da679d7d5b5 100644
--- a/tests/test_optimizer/test_dist_adafactor.py
+++ b/tests/test_optimizer/test_dist_adafactor.py
@@ -316,7 +316,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
            dp_process_group=dp_group,
            verbose=True,
        )
-        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
+        shard_to_param = dist_optim.master_to_working_param  # {id(): param tensor} but flattened
        dist_optim.optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py
index c767e968434d..45fe687b724c 100644
--- a/tests/test_optimizer/test_dist_came.py
+++ b/tests/test_optimizer/test_dist_came.py
@@ -200,7 +200,7 @@ def exam_dist_came_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
            dp_process_group=dp_group,
            verbose=True,
        )
-        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
+        shard_to_param = dist_optim.master_to_working_param  # {id(): param tensor} but flattened
        dist_optim.optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py
index c1ff78c0c276..66e8e49c7801 100644
--- a/tests/test_optimizer/test_dist_lamb.py
+++ b/tests/test_optimizer/test_dist_lamb.py
@@ -229,7 +229,7 @@ def run_dist_lamb_fwd_bwd(
            dp_process_group=dp_group,
            verbose=True,
        )
-        shard_to_param = optim._param_store.master_to_working_param
+        shard_to_param = optim.master_to_working_param
        optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
    else:
        optim.setup_distributed(tp_group)
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
index be257e81860e..e37a050e3dbe 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
@@ -32,6 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group
+    dp_group = booster.plugin.dp_group

    bert = unwrap_model(org_model, "BertModel", "bert")
    sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

@@ -53,8 +54,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        device = origin_norm.device
        norm_groups = []
        for group_id in range(sharded_optimizer.num_param_groups):
-            working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
-            norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
+            working_grads = sharded_optimizer.get_working_grads_by_group_id(group_id)
+            norm_group = sharded_optimizer._compute_grad_norm(dp_group, gradients=working_grads)
            norm_groups.append(norm_group)
        total_norm = 0.0
        for norm in norm_groups:
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index b73552cecb9e..4d66692a4c11 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        for p1, p2 in zip(command_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
-            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
-                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+                0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 3a8a1357deb0..12369289f5a7 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -62,10 +62,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
-            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            working_p = sharded_optimizer.master_to_working_param[id(p2)]
+            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
-                0 if sharded_optimizer._grad_store._partition_grads else sharded_optimizer._bucket_store.zero_local_rank
+                0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
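
Taken together, the patch drops the `optimizer._grad_store.*` / `optimizer._param_store.*` indirection in favor of accessors on the ZeRO optimizer itself, and lets `LowLevelZeroOptimizer` take a `dp_process_group` directly instead of requiring `pg_to_param_list`. The sketch below is illustrative only and is not part of the patch: the helper names (`build_zero_optimizer`, `inspect_grads`) and the `model` argument are hypothetical, the import path is assumed to be the public `colossalai.zero` export, and it presumes `torch.distributed` has already been initialized (e.g. via `colossalai.launch`).

```python
# Hypothetical usage sketch of the refactored API; names marked below are assumptions.
import torch
import torch.distributed as dist
from colossalai.zero import LowLevelZeroOptimizer  # assumed public export path


def build_zero_optimizer(model: torch.nn.Module) -> LowLevelZeroOptimizer:
    # Hypothetical helper: wrap a plain torch optimizer in ZeRO.
    base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    # New in this patch: a data-parallel process group can be passed directly.
    # It is mutually exclusive with pg_to_param_list; passing both raises ValueError.
    return LowLevelZeroOptimizer(
        base_optim,
        partition_grad=True,  # ZeRO stage 2
        overlap_communication=True,
        dp_process_group=dist.group.WORLD,  # or any data-parallel ProcessGroup
    )


def inspect_grads(optimizer: LowLevelZeroOptimizer, group_id: int = 0) -> None:
    # Hypothetical helper, typically called after optimizer.backward(loss).
    # Accessors formerly reached through _grad_store / _param_store are now
    # called on the optimizer itself.
    for grad in optimizer.get_working_grads_by_group_id(group_id):
        param_id = optimizer.get_param_id_for_grad(grad)  # may return None now
        if param_id is not None:
            partitioned = optimizer.get_partitioned_gradients_by_param_id(group_id, param_id)
            print(param_id, tuple(grad.shape), len(partitioned))
```

Under the new check in `__init__`, callers choose either `dp_process_group` or `pg_to_param_list`; providing both raises a `ValueError`.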