103 changes: 92 additions & 11 deletions colossalai/cluster/device_mesh_manager.py
@@ -1,36 +1,117 @@
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import torch
+import torch.distributed as dist
+
+from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
 
 
+@dataclass
+class DeviceMeshInfo:
+    '''
+    This class is used to store the information used to initialize the device mesh.
+
+    Args:
+        physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on an 8-device cluster, the physical ids should be [4, 5, 6, 7].
+        mesh_shape (Union[torch.Size, List[int], Tuple[int]]): The shape of the logical mesh. For example, if we have 4 GPUs and want a 2D mesh, the mesh shape should be [2, 2].
+    '''
+    physical_ids: List[int]
+    mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
+
+    def __post_init__(self):
+        if self.mesh_shape is not None:
+            world_size = len(self.physical_ids)
+            mesh_shape_numel = torch.Size(self.mesh_shape).numel()
+            assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
+
+
+def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
+    '''
+    This method is used to initialize the device mesh.
+
+    Args:
+        device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
+    '''
+    # parse the device mesh info
+    physical_devices = device_mesh_info.physical_ids
+    physical_mesh = torch.tensor(physical_devices)
+    logical_mesh_shape = device_mesh_info.mesh_shape
+
+    if logical_mesh_shape is None:
+        ab_profiler = AlphaBetaProfiler(physical_devices)
+        # search for the best logical mesh shape
+        logical_mesh_id = ab_profiler.search_best_logical_mesh()
+        logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
+
+    else:
+        logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
+
+    device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, logical_mesh_id=logical_mesh_id, init_process_group=True)
+    return device_mesh
+
+
 class DeviceMeshManager:
     """
     Device mesh manager is responsible for creating and managing device meshes.
     """
 
     def __init__(self):
-        self.device_mesh_store = dict()
+        self.device_mesh_store: Dict[str, DeviceMesh] = dict()
 
-    def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
+    def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh:
         """
         Create a device mesh and store it in the manager.
 
         Args:
             name (str): name of the device mesh
-            *args: args for DeviceMesh
-            **kwargs: kwargs for DeviceMesh
-        """
-        # TODO(Yuliang): replace *args, **kwargs with explicit arguments
+            device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
+        """
         if name not in self.device_mesh_store:
-            device_mesh = DeviceMesh(*args, **kwargs)
+            device_mesh = initialize_device_mesh(device_mesh_info)
             self.device_mesh_store[name] = device_mesh
             return device_mesh
         else:
             raise ValueError(f'Device mesh {name} already exists.')
 
     def get(self, name: str) -> DeviceMesh:
-        pass
+        """
+        Get a device mesh by name.
 
-    def destroy(self):
-        pass
+        Args:
+            name (str): name of the device mesh
+
+        Returns:
+            DeviceMesh: the device mesh
+        """
+        if name in self.device_mesh_store:
+            return self.device_mesh_store[name]
+        else:
+            raise ValueError(f'Device mesh {name} does not exist.')
+
+    def destroy(self, name: str) -> None:
+        """
+        Destroy a device mesh by name.
+
+        Args:
+            name (str): name of the device mesh
+        """
+        if name in self.device_mesh_store:
+            for pgs in self.device_mesh_store[name].process_groups_dict.values():
+                for pg in pgs:
+                    dist.destroy_process_group(pg)
+            del self.device_mesh_store[name]
+        else:
+            raise ValueError(f'Device mesh {name} does not exist.')
 
     def destroy_all(self):
-        pass
+        """
+        Destroy all device meshes.
+        """
+        for name in self.device_mesh_store:
+            for pgs in self.device_mesh_store[name].process_groups_dict.values():
+                for pg in pgs:
+                    dist.destroy_process_group(pg)
+
+        self.device_mesh_store.clear()
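
The __post_init__ check above ties the logical mesh shape to the number of physical devices. A minimal sketch of that validation follows (it assumes only that DeviceMeshInfo is importable, as in the test file below; no distributed setup is required, and the variable names are illustrative):

from colossalai.cluster.device_mesh_manager import DeviceMeshInfo

# Matching sizes: four physical ids fill a 2 x 2 logical mesh, so no error is raised.
ok_info = DeviceMeshInfo(physical_ids=[0, 1, 2, 3], mesh_shape=(2, 2))

# Mismatched sizes trip the assertion in __post_init__: four ids cannot fill a 2 x 3 mesh.
try:
    DeviceMeshInfo(physical_ids=[0, 1, 2, 3], mesh_shape=(2, 3))
except AssertionError as error:
    print(error)  # the numel of mesh_shape should be equal to world size, but got 4 != 6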
40 changes: 40 additions & 0 deletions tests/test_cluster/test_device_mesh_manager.py
@@ -0,0 +1,40 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.utils import free_port
+
+
+def check_device_mesh_manager(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    device_mesh_manager = DeviceMeshManager()
+    device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],)
+    device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto)
+    assert device_mesh_auto.shape == (2, 2)
+    assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]]
+
+    device_mesh_info_with_shape = DeviceMeshInfo(
+        physical_ids=[0, 1, 2, 3],
+        mesh_shape=(2, 2),
+    )
+    device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape)
+
+    assert device_mesh_with_shape.shape == (2, 2)
+    assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]
+
+
+def test_device_mesh_manager():
+    world_size = 4
+    run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_device_mesh_manager()
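
The test above only exercises create_device_mesh. Below is a hedged sketch of how get, destroy, and destroy_all would round out the lifecycle inside the same kind of launched worker; it assumes the 4-GPU NCCL setup used by the test, and the helper name check_manager_lifecycle and the mesh name 'mesh' are illustrative, not part of the PR:

from functools import partial

import torch.multiprocessing as mp

from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port


def check_manager_lifecycle(rank, world_size, port):
    # Same launch pattern as check_device_mesh_manager above.
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    manager = DeviceMeshManager()
    manager.create_device_mesh('mesh', DeviceMeshInfo(physical_ids=[0, 1, 2, 3], mesh_shape=(2, 2)))

    mesh = manager.get('mesh')    # returns the stored DeviceMesh
    assert mesh.shape == (2, 2)

    manager.destroy('mesh')       # destroys the mesh's process groups and drops it from the store
    manager.destroy_all()         # clears any remaining meshes (none are left at this point)


if __name__ == '__main__':
    mp.spawn(partial(check_manager_lifecycle, world_size=4, port=free_port()), nprocs=4)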