From 47e84911697b89bc51d3d3efcd7bf44824c12e1e Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Thu, 15 Jun 2023 10:45:02 +0800
Subject: [PATCH] [device] support init device mesh from process group

---
 colossalai/device/device_mesh.py      | 110 ++++++++++++++++++++++++--
 tests/test_device/test_device_mesh.py |  70 ++++++++++++++++
 2 files changed, 173 insertions(+), 7 deletions(-)

diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 0490a440153e..d67364d9785c 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -62,11 +62,11 @@ def __init__(self,
             "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
 
         if logical_mesh_id is None:
-            self.mesh_shape = mesh_shape
-            self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape)
+            self._mesh_shape = mesh_shape
+            self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
         else:
             self._logical_mesh_id = logical_mesh_id
-            self.mesh_shape = self._logical_mesh_id.shape
+            self._mesh_shape = self._logical_mesh_id.shape
 
         # ensure two things:
         # 1. logical and physical mesh IDs should contain the same elements
@@ -84,9 +84,9 @@ def __init__(self,
         # ===============================================
         # if the values are not provided, we assume they are 1 for simplicity
         if mesh_alpha is None:
-            mesh_alpha = [1] * len(self.mesh_shape)
+            mesh_alpha = [1] * len(self._mesh_shape)
         if mesh_beta is None:
-            mesh_beta = [1] * len(self.mesh_shape)
+            mesh_beta = [1] * len(self._mesh_shape)
 
         self.mesh_alpha = tuple(mesh_alpha)
         self.mesh_beta = tuple(mesh_beta)
@@ -118,6 +118,13 @@ def __init__(self,
         self._global_rank_of_current_process = None
         self._is_initialized = False
 
+        # attribute used to indicate whether this object
+        # is created using DeviceMesh.from_process_group
+        # this attribute can be used to do some checks in methods
+        # such as get_process_group, as no global rank information
+        # is known if created with from_process_group
+        self._is_init_from_process_group = False
+
         # initialize process group if specified
         self._init_ranks_in_the_same_group()
         self._init_process_group = init_process_group
@@ -129,7 +136,7 @@ def shape(self) -> torch.Size:
         """
         Return the shape of the logical mesh.
         """
-        return self.mesh_shape
+        return self._mesh_shape
 
     @property
     def num_devices(self) -> int:
@@ -145,6 +152,73 @@ def logical_mesh_id(self) -> torch.Tensor:
         """
         return self._logical_mesh_id
 
+    @property
+    def is_initialized(self) -> bool:
+        """
+        Return whether the process groups of the device mesh are initialized.
+        """
+        return self._is_initialized
+
+    @staticmethod
+    def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
+        """
+        Create a DeviceMesh instance from an existing process group or a list of process groups. Please note that a DeviceMesh object
+        created with this method will not have information about the physical mesh id, and thus will not be able to query other ranks or perform alpha-beta communication.
+
+        Args:
+            process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
+                If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
+                the ProcessGroup at the i-th index will correspond to the i-th axis of the device mesh.
+
+        Returns:
+            DeviceMesh: the device mesh instance.
+ """ + + def _get_device_by_backend(process_group): + """ + Get the device type given a process group's backend. + """ + backend = dist.get_backend(process_group) + for _device, _backend in DeviceMesh._DIST_BACKEND.items(): + if _backend == backend: + return _device + return None + + if isinstance(process_group, ProcessGroup): + process_group = [process_group] + + # get mesh shape + mesh_shape = [dist.get_world_size(pg) for pg in process_group] + + # get device + device_list = [_get_device_by_backend(pg) for pg in process_group] + + # make sure all devices are the same + assert all([device == device_list[0] for device in device_list]), \ + "All devices should be the same, please check your input process groups are created with the same distributed backend." + + # create a fake physical mesh id + # as we only get the process group associated with the current process, + # we cannot get the global ranks for all processes in the mesh + # therefore, we only use this fake physical mesh id to create the device mesh + # and will remove this fake physical mesh id later + fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1)) + + # create the device mesh + device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0]) + + # hack the device attribute + device_mesh._physical_mesh_id = None + device_mesh._logical_mesh_id = None + device_mesh._global_rank_of_current_process = dist.get_rank() + device_mesh._is_initialized = False + device_mesh._process_group_dict = { + device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)} + } + + return device_mesh + def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: """ Return the process group on the specified axis. @@ -155,6 +228,10 @@ def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: """ if global_rank is None: global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) return self._process_group_dict[global_rank][axis] def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: @@ -166,6 +243,10 @@ def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, P """ if global_rank is None: global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) return self._process_group_dict[global_rank] def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: @@ -178,6 +259,10 @@ def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List """ if global_rank is None: global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." 
+            )
         return self._ranks_in_the_process_group[global_rank][axis]
 
     def __deepcopy__(self, memo) -> "DeviceMesh":
@@ -292,6 +378,11 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[i
             rank (int): the global rank in the logical device mesh.
             axis (int): the axis of the logical device mesh.
         """
+        if self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported for this creation mode as no global rank information is known."
+            )
+
         local_ranks = self._global_to_local_rank_mapping[rank]
         if axis:
             return local_ranks[axis]
@@ -381,7 +472,12 @@ def flatten(self):
         """
         Flatten the logical mesh into an effective 1d logical mesh.
         """
-        flatten_mesh_shape_size = len(self.mesh_shape)
+        if self._is_init_from_process_group:
+            raise RuntimeError(
+                "The logical device mesh was created with DeviceMesh.from_process_group; this method is not supported for this creation mode as no global rank information is known."
+            )
+
+        flatten_mesh_shape_size = len(self._mesh_shape)
         flatten_mesh_shape = [self.num_devices]
         return DeviceMesh(self._physical_mesh_id,
                           tuple(flatten_mesh_shape),
diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py
index 19d41d23353f..590d6966bff6 100644
--- a/tests/test_device/test_device_mesh.py
+++ b/tests/test_device/test_device_mesh.py
@@ -1,6 +1,10 @@
+import pytest
 import torch
+import torch.distributed as dist
 
+import colossalai
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
 def test_device_mesh():
@@ -16,5 +20,71 @@ def test_device_mesh():
     assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
 
 
+def check_1d_device_mesh():
+    # check for 1D device mesh
+    process_group = dist.GroupMember.WORLD
+    device_mesh = DeviceMesh.from_process_group(process_group)
+
+    # checks
+    assert device_mesh.shape == [4]
+    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict'
+    assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group'
+    assert device_mesh.is_initialized
+    assert device_mesh.num_devices == 4
+    assert device_mesh.logical_mesh_id is None
+    assert device_mesh._is_init_from_process_group
+
+
+def check_2d_device_mesh():
+    # create process groups for the 2D device mesh
+    first_row_ranks = [0, 1]
+    second_row_ranks = [2, 3]
+    first_col_ranks = [0, 2]
+    second_col_ranks = [1, 3]
+
+    first_row_pg = dist.new_group(first_row_ranks, backend='nccl')
+    second_row_pg = dist.new_group(second_row_ranks, backend='nccl')
+    first_col_pg = dist.new_group(first_col_ranks, backend='nccl')
+    second_col_pg = dist.new_group(second_col_ranks, backend='nccl')
+
+    # pick the row and column process groups which contain the current rank
+    current_rank = dist.get_rank()
+
+    if current_rank in first_row_ranks:
+        row_pg = first_row_pg
+    else:
+        row_pg = second_row_pg
+
+    if current_rank in first_col_ranks:
+        col_pg = first_col_pg
+    else:
+        col_pg = second_col_pg
+
+    device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])
+
+    # checks
+    assert device_mesh.shape == [2, 2]
+    assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict'
+    assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group'
+    assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group'
+    assert device_mesh.num_devices == 4
+    assert device_mesh.is_initialized
+    assert device_mesh.logical_mesh_id is None
+    assert device_mesh._is_init_from_process_group
+
+
+def check_init_from_process_group(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_1d_device_mesh()
+    check_2d_device_mesh()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_device_mesh_from_process_group():
+    spawn(check_init_from_process_group, 4)
+
 
 if __name__ == '__main__':
     test_device_mesh()
+    test_device_mesh_from_process_group()
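
---

Usage note (illustration only, not part of the patch): the sketch below shows how DeviceMesh.from_process_group could be driven from a plain torch.distributed script, mirroring the 2x2 layout used in check_2d_device_mesh above. It assumes 4 ranks launched with torchrun and an NCCL backend; the helper name build_2x2_mesh is hypothetical. Because torch.distributed requires every rank to create every subgroup in the same order, all four groups are created on all ranks before each rank picks the two groups it belongs to.

import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh


def build_2x2_mesh() -> DeviceMesh:
    # every rank must create all subgroups, in the same order
    row_pgs = [dist.new_group([0, 1], backend='nccl'), dist.new_group([2, 3], backend='nccl')]
    col_pgs = [dist.new_group([0, 2], backend='nccl'), dist.new_group([1, 3], backend='nccl')]

    # pick the groups which contain the current rank
    rank = dist.get_rank()
    row_pg = row_pgs[rank // 2]    # ranks 0, 1 -> first row; ranks 2, 3 -> second row
    col_pg = col_pgs[rank % 2]     # ranks 0, 2 -> first column; ranks 1, 3 -> second column

    # axis 0 is the column axis and axis 1 is the row axis, matching the test above
    return DeviceMesh.from_process_group([col_pg, row_pg])


if __name__ == '__main__':
    # e.g. torchrun --nproc_per_node=4 this_script.py
    dist.init_process_group(backend='nccl')
    mesh = build_2x2_mesh()
    assert mesh.shape == [2, 2]
    assert mesh.get_process_group(axis=1) == mesh.get_process_group_for_all_axes()[1]

Note that such a mesh only knows the groups of the local rank: per the guards added in this patch, querying other ranks (e.g. get_process_group(axis=0, global_rank=3)) or calling flatten() on it raises a RuntimeError.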