From e09c6b7caefc9f171235ce892f540f5484315fdd Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Mon, 25 Sep 2023 14:08:06 +0800 Subject: [PATCH] Correct several erroneous code comments --- colossalai/shardformer/policies/base_policy.py | 2 +- colossalai/zero/low_level/bookkeeping/bucket_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129a00..eb03500531bc 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -50,7 +50,7 @@ def example_replace_weight(module: torch.nn.Module): new_weight = shard_rowwise(weight, process_group) module.weight = torch.nn.Parameter(new_weight) ``` - sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription object which specifies the module to be replaced and the target module used to replacement. method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2a75d704711a..2828d517573d 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -92,7 +92,7 @@ def get_grad(self) -> Dict: def get_flatten_grad(self) -> Tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: - [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] + [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] Returns: Tensor: the flattened gradients slices in the bucket