Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions colossalai/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager

__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
36 changes: 36 additions & 0 deletions colossalai/cluster/device_mesh_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from colossalai.device.device_mesh import DeviceMesh


class DeviceMeshManager:
"""
Device mesh manager is responsible for creating and managing device meshes.
"""

def __init__(self):
self.device_mesh_store = dict()

def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
"""
Create a device mesh and store it in the manager.

Args:
name (str): name of the device mesh
*args: args for DeviceMesh
**kwargs: kwargs for DeviceMesh
"""
# TODO(Yuliang): replace *args, **kwargs with explicit arguments
if name not in self.device_mesh_store:
device_mesh = DeviceMesh(*args, **kwargs)
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
raise ValueError(f'Device mesh {name} already exists.')

def get(self, name: str) -> DeviceMesh:
pass

def destroy(self):
pass

def destroy_all(self):
pass
158 changes: 158 additions & 0 deletions colossalai/cluster/dist_coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os
from contextlib import contextmanager

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.context.singleton_meta import SingletonMeta


class DistCoordinator(metaclass=SingletonMeta):
"""
This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
class in the whole program.

There are some terms that are used in this class:
- rank: the rank of the current process
- world size: the total number of processes
- local rank: the rank of the current process on the current node
- master: the process with rank 0
- node master: the process with local rank 0 on the current node

Example:
>>> from colossalai.cluster.dist_coordinator import DistCoordinator
>>> coordinator = DistCoordinator()
>>>
>>> if coordinator.is_master():
>>> do_something()
>>>
>>> coordinator.print_on_master('hello world')

Attributes:
rank (int): the rank of the current process
world_size (int): the total number of processes
local_rank (int): the rank of the current process on the current node
"""

def __init__(self):
assert dist.is_initialized(
), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
self._local_rank = os.environ.get('LOCAL_RANK', -1)

@property
def rank(self) -> int:
return self._rank

@property
def world_size(self) -> int:
return self._world_size

@property
def local_rank(self) -> int:
return self._local_rank

def _assert_local_rank_set(self):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'

def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process.

Args:
process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.

Returns:
bool: True if the current process is the master process, False otherwise
"""
rank = dist.get_rank(group=process_group)
return rank == 0

def is_node_master(self) -> bool:
"""
Check if the current process is the master process on the current node (local rank is 0).

Returns:
bool: True if the current process is the master process on the current node, False otherwise
"""
self._assert_local_rank_set()
return self.local_rank == 0

def is_last_process(self, process_group: ProcessGroup = None) -> bool:
"""
Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process.

Args:
process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.

Returns:
bool: True if the current process is the last process, False otherwise
"""
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
return rank == world_size - 1

def print_on_master(self, msg: str, process_group: ProcessGroup = None):
"""
Print message only from rank 0.

Args:
msg (str): message to print
process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
"""
rank = dist.get_rank(group=process_group)
if rank == 0:
print(msg)

def print_on_node_master(self, msg: str):
"""
Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.

Args:
msg (str): message to print
"""
self._assert_local_rank_set()
if self.local_rank == 0:
print(msg)

@contextmanager
def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.

Example:
>>> from colossalai.cluster import DistCoordinator
>>> dist_coordinator = DistCoordinator()
>>> with dist_coordinator.priority_execution():
>>> dataset = CIFAR10(root='./data', download=True)

Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
"""
rank = dist.get_rank(group=process_group)
should_block = rank != executor_rank

if should_block:
dist.barrier(group=process_group)

yield

if not should_block:
dist.barrier(group=process_group)

def destroy(self, process_group: ProcessGroup = None):
"""
Destroy the distributed process group.

Args:
process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
"""
dist.destroy_process_group(process_group)
75 changes: 75 additions & 0 deletions colossalai/cluster/process_group_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import List

import torch.distributed as dist
from torch.distributed import ProcessGroup


class ProcessGroupManager:
"""
ProcessGroupManager is used to manage the process groups in the cluster.

There are some terms used in this class:
- pg: the short name for process group
- pg_name: the name of the process group
- pg_size: the world size of the process group
- rank: the rank of the current process in the process group
- world_size: the total number of processes in the process group
"""

def __init__(self):
self.pg_store = dict()

def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.

Args:
name (str): name of the process group
ranks (List[int]): ranks of the process group
backend (str, optional): backend of the process group. Defaults to 'nccl'.

Returns:
ProcessGroup: the process group
"""
if name not in self.pg_store:
pg = dist.new_group(ranks=ranks, backend=backend)
self.pg_store[name] = pg
return pg
else:
raise ValueError(f'Process group {name} already exists.')

def get(self, name: str) -> ProcessGroup:
"""
Get a process group by name.

Args:
name (str): name of the process group

Returns:
ProcessGroup: the process group
"""
if name in self.pg_store:
return self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')

def destroy(self, name: str) -> None:
"""
Destroy a process group by name.

Args:
name (str): name of the process group
"""
if name in self.pg_store:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
raise ValueError(f'Process group {name} does not exist.')

def destroy_all(self) -> None:
"""
Destroy all process groups.
"""
for name in self.pg_store:
dist.destroy_process_group(self.pg_store[name])
self.pg_store.clear()