diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index e20d846f1071..3863394b1705 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -425,8 +425,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: param = self.id_to_real_params[param_id] fake_param = self.id_to_fake_params.get(param_id, None) chunk = self.chunk_manager.get_chunk(param) - dp_group = chunk.torch_pg - rank = dist.get_rank(dp_group) + zero_group = chunk.torch_pg + rank = dist.get_rank(zero_group) master_rank = 0 collected_states = {} @@ -434,9 +434,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: local_state_names = None if fake_param is not None: local_state_names = list(self.optim.state[fake_param].keys()) - gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))] + gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))] dist.barrier() - dist.all_gather_object(gathered_state_names, local_state_names, dp_group) + dist.all_gather_object(gathered_state_names, local_state_names, zero_group) state_names = None for names in gathered_state_names: if names is not None: @@ -510,10 +510,10 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: _, shard_offset, shard_size = self.get_offsets(param_id) # Collectors gather state shards through all_gathering. - gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] + gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))] dist.barrier() - dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) + dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group) if is_collector: for state_shard in gathered_state_shards: