109 changes: 102 additions & 7 deletions colossalai/device/device_mesh.py
@@ -62,11 +62,11 @@ def __init__(self,
"Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"

if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape)
self._mesh_shape = mesh_shape
self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
self._mesh_shape = self._logical_mesh_id.shape

# ensure two things:
# 1. logical and physical mesh IDs should contain the same elements
@@ -84,9 +84,9 @@ def __init__(self,
# ===============================================
# if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape)
mesh_alpha = [1] * len(self._mesh_shape)
if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape)
mesh_beta = [1] * len(self._mesh_shape)

self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
@@ -118,6 +118,13 @@ def __init__(self,
self._global_rank_of_current_process = None
self._is_initialized = False

# attribute used to indicate whether this object
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some checks in methods
# such as get_process_group, as no global rank information
# is known if created with from_process_group
self._is_init_from_process_group = False

# initialize process group if specified
self._init_ranks_in_the_same_group()
self._init_process_group = init_process_group
@@ -129,7 +136,7 @@ def shape(self) -> torch.Size:
"""
Return the shape of the logical mesh.
"""
return self.mesh_shape
return self._mesh_shape

@property
def num_devices(self) -> int:
@@ -145,6 +152,72 @@ def logical_mesh_id(self) -> torch.Tensor:
"""
return self._logical_mesh_id

@property
def is_initialized(self) -> bool:
"""
Return whether the process group is initialized.
"""
return self._is_initialized

@staticmethod
def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
"""
Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.

Args:
process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.

Returns:
DeviceMesh: the device mesh instance.
"""

def _get_device_by_backend(process_group):
"""
Get the device type given a process group's backend.
"""
backend = dist.get_backend(process_group)
for _device, _backend in DeviceMesh._DIST_BACKEND.items():
if _backend == backend:
return _device
return None

if isinstance(process_group, ProcessGroup):
process_group = [process_group]

# get mesh shape
mesh_shape = [dist.get_world_size(pg) for pg in process_group]

# get device
device_list = [_get_device_by_backend(pg) for pg in process_group]

# make sure all devices are the same
assert all([device == device_list[0] for device in device_list]), \
"All devices should be the same, please check your input process groups are created with the same distributed backend."

# create a fake physical mesh id
# as we only get the process group associated with the current process,
# we cannot get the global ranks for all processes in the mesh
# therefore, we only use this fake physical mesh id to create the device mesh
# and will remove this fake physical mesh id later
fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))

# create the device mesh
device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])

# hack the attributes as the fake physical mesh id carries no real global rank information
device_mesh._physical_mesh_id = None
device_mesh._logical_mesh_id = None
device_mesh._global_rank_of_current_process = dist.get_rank()
device_mesh._is_initialized = False
device_mesh._is_init_from_process_group = True
device_mesh._process_group_dict = {
device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
}

return device_mesh

def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
"""
Return the process group on the specified axis.
@@ -155,6 +228,10 @@ def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._process_group_dict[global_rank][axis]

def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
@@ -166,6 +243,10 @@ def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._process_group_dict[global_rank]

def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
@@ -178,6 +259,10 @@ def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
elif self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)
return self._ranks_in_the_process_group[global_rank][axis]

def __deepcopy__(self, memo) -> "DeviceMesh":
@@ -292,6 +377,11 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
rank (int): the global rank in the logical device mesh.
axis (int): the axis of the logical device mesh.
"""
if self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)

local_ranks = self._global_to_local_rank_mapping[rank]
if axis:
return local_ranks[axis]
@@ -381,7 +471,12 @@ def flatten(self):
"""
Flatten the logical mesh into an effective 1D logical mesh.
"""
flatten_mesh_shape_size = len(self.mesh_shape)
if self._is_init_from_process_group:
raise RuntimeError(
"The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
)

flatten_mesh_shape_size = len(self._mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id,
tuple(flatten_mesh_shape),
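Taken together, the new from_process_group constructor turns a single ProcessGroup into a 1D mesh and a list of ProcessGroups into one mesh axis per entry, while leaving the physical mesh id and global rank mapping unset. A minimal usage sketch (not part of the diff), assuming a 4-rank job already launched with the NCCL backend and mirroring the group layout used in the test below:

import torch.distributed as dist
from colossalai.device.device_mesh import DeviceMesh

# 1D mesh: the world process group becomes a mesh with a single axis
mesh_1d = DeviceMesh.from_process_group(dist.GroupMember.WORLD)
assert mesh_1d.shape == [4]

# 2D mesh: every rank creates all row/column groups, then keeps the ones it belongs to;
# the i-th entry of the list becomes the i-th mesh axis
first_row_pg, second_row_pg = dist.new_group([0, 1]), dist.new_group([2, 3])
first_col_pg, second_col_pg = dist.new_group([0, 2]), dist.new_group([1, 3])
rank = dist.get_rank()
row_pg = first_row_pg if rank in (0, 1) else second_row_pg
col_pg = first_col_pg if rank in (0, 2) else second_col_pg
mesh_2d = DeviceMesh.from_process_group([col_pg, row_pg])
assert mesh_2d.get_process_group(axis=0) == col_pg
assert mesh_2d.get_process_group(axis=1) == row_pg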
69 changes: 69 additions & 0 deletions tests/test_device/test_device_mesh.py
@@ -1,6 +1,10 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import rerun_if_address_is_in_use, spawn


def test_device_mesh():
@@ -16,5 +20,70 @@ def test_device_mesh():
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


def check_1d_device_mesh():
# check for 1D device mesh
process_group = dist.GroupMember.WORLD
device_mesh = DeviceMesh.from_process_group(process_group)

# checks
assert device_mesh.shape == [4]
assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict'
assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group'
assert not device_mesh.is_initialized
assert device_mesh.num_devices == 4
assert device_mesh.logical_mesh_id is None
assert device_mesh._is_init_from_process_group


def check_2d_device_mesh():
# create process group for 2D device mesh
first_row_ranks = [0, 1]
second_row_ranks = [2, 3]
first_col_ranks = [0, 2]
second_col_ranks = [1, 3]

first_row_pg = dist.new_group(first_row_ranks, backend='nccl')
second_row_pg = dist.new_group(second_row_ranks, backend='nccl')
first_col_pg = dist.new_group(first_col_ranks, backend='nccl')
second_col_pg = dist.new_group(second_col_ranks, backend='nccl')

# select the row and column process groups that the current rank belongs to
current_rank = dist.get_rank()

if current_rank in first_row_ranks:
row_pg = first_row_pg
else:
row_pg = second_row_pg

if current_rank in first_col_ranks:
col_pg = first_col_pg
else:
col_pg = second_col_pg

device_mesh = DeviceMesh.from_process_group([col_pg, row_pg])

# checks
assert device_mesh.shape == [2, 2]
assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict'
assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group'
assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group'
assert device_mesh.num_devices == 4
assert not device_mesh.is_initialized
assert device_mesh.logical_mesh_id is None
assert device_mesh._is_init_from_process_group


def check_init_from_process_group(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_1d_device_mesh()
check_2d_device_mesh()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_device_mesh_from_process_group():
spawn(check_init_from_process_group, 4)


if __name__ == '__main__':
test_device_mesh()
test_device_mesh_from_process_group()
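
Usage note (not part of the diff): the spawned test above launches 4 ranks on the NCCL backend, so running it locally assumes a host with at least 4 CUDA devices; invoking pytest tests/test_device/test_device_mesh.py (or the __main__ block) should then exercise both test functions.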