diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 30ac4d354647..2fa65c970316 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -72,7 +72,7 @@ def register_tensor(self, if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = tensor.process_group.dp_world_size() + dp_size = tensor.get_dp_world_size() chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index bbed8847abbc..40eefc3ec5d1 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -138,6 +138,15 @@ def set_process_group(self, pg: ProcessGroup): def get_tp_world_size(self) -> int: return self.process_group.tp_world_size() + def get_dp_world_size(self) -> int: + """get_dp_world_size + get the dp world size of the tensor. + + Returns: + int: dp world size + """ + return self.process_group.dp_world_size() + def set_dist_spec(self, dist_spec: _DistSpec): """set_dist_spec set dist spec and change the payloads.