From 0d76d7cf22696fc9466e294e140690f87954cca9 Mon Sep 17 00:00:00 2001 From: Younghwan Na <100389977+yhna940@users.noreply.github.com> Date: Fri, 24 Mar 2023 10:35:41 +0900 Subject: [PATCH] Add interface for colo tesnor dp size --- colossalai/gemini/chunk/manager.py | 2 +- colossalai/tensor/colo_tensor.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 30ac4d354647..2fa65c970316 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -72,7 +72,7 @@ def register_tensor(self, if tensor.numel() > chunk_size: chunk_size = tensor.numel() - dp_size = tensor.process_group.dp_world_size() + dp_size = tensor.get_dp_world_size() chunk_size = chunk_size + (-chunk_size % dp_size) chunk = Chunk( diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index bbed8847abbc..40eefc3ec5d1 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -138,6 +138,15 @@ def set_process_group(self, pg: ProcessGroup): def get_tp_world_size(self) -> int: return self.process_group.tp_world_size() + def get_dp_world_size(self) -> int: + """get_dp_world_size + get the dp world size of the tensor. + + Returns: + int: dp world size + """ + return self.process_group.dp_world_size() + def set_dist_spec(self, dist_spec: _DistSpec): """set_dist_spec set dist spec and change the payloads.