diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 744799182e22..8754baa19792 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -1,36 +1,117 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist + +from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh +@dataclass +class DeviceMeshInfo: + ''' + This class is used to store the information used to initialize the device mesh. + + Args: + physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. + mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. + ''' + physical_ids: List[int] + mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None + + def __post_init__(self): + if self.mesh_shape is not None: + world_size = len(self.physical_ids) + mesh_shape_numel = torch.Size(self.mesh_shape).numel() + assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + + +def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): + ''' + This method is used to initialize the device mesh. + + Args: + device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. + ''' + # parse the device mesh info + physical_devices = device_mesh_info.physical_ids + physical_mesh = torch.tensor(physical_devices) + logical_mesh_shape = device_mesh_info.mesh_shape + + if logical_mesh_shape is None: + ab_profiler = AlphaBetaProfiler(physical_devices) + # search for the best logical mesh shape + logical_mesh_id = ab_profiler.search_best_logical_mesh() + logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int) + + else: + logical_mesh_id = physical_mesh.reshape(logical_mesh_shape) + + device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, logical_mesh_id=logical_mesh_id, init_process_group=True) + return device_mesh + + class DeviceMeshManager: """ Device mesh manager is responsible for creating and managing device meshes. """ def __init__(self): - self.device_mesh_store = dict() + self.device_mesh_store: Dict[str, DeviceMesh] = dict() - def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh: + def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh: """ Create a device mesh and store it in the manager. Args: name (str): name of the device mesh - *args: args for DeviceMesh - **kwargs: kwargs for DeviceMesh - """ - # TODO(Yuliang): replace *args, **kwargs with explicit arguments + device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh + """ if name not in self.device_mesh_store: - device_mesh = DeviceMesh(*args, **kwargs) + device_mesh = initialize_device_mesh(device_mesh_info) self.device_mesh_store[name] = device_mesh return device_mesh else: raise ValueError(f'Device mesh {name} already exists.') def get(self, name: str) -> DeviceMesh: - pass + """ + Get a device mesh by name. - def destroy(self): - pass + Args: + name (str): name of the device mesh + + Returns: + DeviceMesh: the device mesh + """ + if name in self.device_mesh_store: + return self.device_mesh_store[name] + else: + raise ValueError(f'Device mesh {name} does not exist.') + + def destroy(self, name: str) -> None: + """ + Destroy a device mesh by name. + + Args: + name (str): name of the device mesh + """ + if name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + del self.device_mesh_store[name] + else: + raise ValueError(f'Device mesh {name} does not exist.') def destroy_all(self): - pass + """ + Destroy all device meshes. + """ + for name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + + self.device_mesh_store.clear() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py new file mode 100644 index 000000000000..b79814735325 --- /dev/null +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -0,0 +1,40 @@ +from functools import partial + +import torch +import torch.multiprocessing as mp + +from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port + + +def check_device_mesh_manager(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + device_mesh_manager = DeviceMeshManager() + device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) + device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto) + assert device_mesh_auto.shape == (2, 2) + assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + device_mesh_info_with_shape = DeviceMeshInfo( + physical_ids=[0, 1, 2, 3], + mesh_shape=(2, 2), + ) + device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + + assert device_mesh_with_shape.shape == (2, 2) + assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + +def test_device_mesh_manager(): + world_size = 4 + run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_device_mesh_manager()