diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129a00..eb03500531bc 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -50,7 +50,7 @@ def example_replace_weight(module: torch.nn.Module): new_weight = shard_rowwise(weight, process_group) module.weight = torch.nn.Parameter(new_weight) ``` - sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription object which specifies the module to be replaced and the target module used to replacement. method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2a75d704711a..2828d517573d 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -92,7 +92,7 @@ def get_grad(self) -> Dict: def get_flatten_grad(self) -> Tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: - [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] + [grad0_rank0, grad1_rank0, ..., grad0_rank1, grad1_rank1, ....] Returns: Tensor: the flattened gradients slices in the bucket