From 1ecbfff6a263ee391d1265ad0f4dcacfabcc0e05 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Thu, 23 Mar 2023 18:07:20 +0800 Subject: [PATCH 1/2] [API] implement device mesh manager --- colossalai/cluster/device_mesh_manager.py | 103 ++++++++++++++++-- .../test_cluster/test_device_mesh_manager.py | 40 +++++++ 2 files changed, 132 insertions(+), 11 deletions(-) create mode 100644 tests/test_cluster/test_device_mesh_manager.py diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 744799182e22..0360a5ec02a4 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -1,36 +1,117 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist + +from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh +@dataclass +class DeviceMeshInfo: + ''' + This class is used to store the information used to initialize the device mesh. + + Args: + physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. + mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. + ''' + physical_ids: List[int] + mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None + + def __post_init__(self): + if self.mesh_shape is not None: + world_size = len(self.physical_ids) + mesh_shape_numel = torch.Size(self.mesh_shape).numel() + assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + + +def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): + ''' + This method is used to initialize the device mesh. + + Args: + device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. + ''' + # parse the device mesh info + physical_devices = device_mesh_info.physical_ids + physical_mesh = torch.tensor(physical_devices) + logical_mesh_shape = device_mesh_info.mesh_shape + + if logical_mesh_shape is None: + ab_profiler = AlphaBetaProfiler(physical_devices) + # search for the best logical mesh shape + logical_mesh_id = ab_profiler.search_best_logical_mesh() + logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int) + + else: + logical_mesh_id = physical_mesh.reshape(logical_mesh_shape) + + device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, logical_mesh_id=logical_mesh_id, init_process_group=True) + return device_mesh + + class DeviceMeshManager: """ Device mesh manager is responsible for creating and managing device meshes. """ def __init__(self): - self.device_mesh_store = dict() + self.device_mesh_store: Dict[str, DeviceMesh] = dict() - def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh: + def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh: """ Create a device mesh and store it in the manager. Args: name (str): name of the device mesh - *args: args for DeviceMesh - **kwargs: kwargs for DeviceMesh - """ - # TODO(Yuliang): replace *args, **kwargs with explicit arguments + device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh + """ if name not in self.device_mesh_store: - device_mesh = DeviceMesh(*args, **kwargs) + device_mesh = initialize_device_mesh(device_mesh_info) self.device_mesh_store[name] = device_mesh return device_mesh else: raise ValueError(f'Device mesh {name} already exists.') def get(self, name: str) -> DeviceMesh: - pass + """ + Get a device mesh by name. - def destroy(self): - pass + Args: + name (str): name of the device mesh + + Returns: + DeviceMesh: the device mesh + """ + if name in self.device_mesh_store: + return self.device_mesh_store[name] + else: + raise ValueError(f'Device mesh {name} does not exist.') + + def destroy(self, name: str) -> None: + """ + Destroy a device mesh by name. + + Args: + name (str): name of the device mesh + """ + if name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + del self.device_mesh_store[name] + else: + raise ValueError(f'Process group {name} does not exist.') def destroy_all(self): - pass + """ + Destroy all device meshes. + """ + for name in self.device_mesh_store: + for pgs in self.device_mesh_store[name].process_groups_dict.values(): + for pg in pgs: + dist.destroy_process_group(pg) + + self.device_mesh_store.clear() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py new file mode 100644 index 000000000000..b79814735325 --- /dev/null +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -0,0 +1,40 @@ +from functools import partial + +import torch +import torch.multiprocessing as mp + +from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx.tracer import ColoTracer +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.utils import free_port + + +def check_device_mesh_manager(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + device_mesh_manager = DeviceMeshManager() + device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) + device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto) + assert device_mesh_auto.shape == (2, 2) + assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + device_mesh_info_with_shape = DeviceMeshInfo( + physical_ids=[0, 1, 2, 3], + mesh_shape=(2, 2), + ) + device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + + assert device_mesh_with_shape.shape == (2, 2) + assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] + + +def test_device_mesh_manager(): + world_size = 4 + run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_device_mesh_manager() From ae5836324b62aa807ed55009a91dffb5a4527056 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Fri, 24 Mar 2023 11:18:53 +0800 Subject: [PATCH 2/2] polish --- colossalai/cluster/device_mesh_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 0360a5ec02a4..8754baa19792 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -103,7 +103,7 @@ def destroy(self, name: str) -> None: dist.destroy_process_group(pg) del self.device_mesh_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f'Device mesh {name} does not exist.') def destroy_all(self): """