1 change: 1 addition & 0 deletions colossalai/__init__.py
@@ -1,4 +1,5 @@
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
from . import accelerator

try:
# .version will be created by setup.py
20 changes: 20 additions & 0 deletions colossalai/accelerator/README.md
@@ -0,0 +1,20 @@
# 🚀 Accelerator

## 🔗 Table of Contents

- [🚀 Accelerator](#-accelerator)
- [🔗 Table of Contents](#-table-of-contents)
- [📚 Introduction](#-introduction)
- [📌 Design and Acknowledgement](#-design-and-acknowledgement)

## 📚 Introduction

This module offers a layer of abstraction for ColossalAI. With this module, users can easily switch between different accelerator backends, such as Nvidia GPUs and Huawei NPUs. It is an attempt to make user code portable across different hardware platforms with a simple `auto_set_accelerator()` API.
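
Below is a minimal usage sketch (not part of this diff; it assumes a machine with at least one supported device) showing how user code can stay backend-agnostic:

```python
from colossalai.accelerator import auto_set_accelerator, get_accelerator

# pick the first available backend (CUDA is checked before NPU)
auto_set_accelerator()

accelerator = get_accelerator()
print(accelerator.name)            # e.g. "cuda" or "npu"
print(accelerator.device_count())  # number of devices on this machine
accelerator.set_device(0)          # bind the current process to device 0
```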

## 📌 Design and Acknowledgement

Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work.

We implemented this accelerator module from scratch, while making the following modifications of our own:
1. We aligned the accelerator API names with PyTorch's native API names (illustrated in the sketch after this list).
2. We did not include the `op builder` in the `accelerator` module. Instead, we restructured our `kernel` module to automatically match the accelerator with its corresponding kernel implementations, keeping the two modules less entangled.
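
As a rough illustration of point 1 (a sketch, not taken from this diff), the accelerator methods mirror the corresponding `torch.cuda` calls, so hard-coded device calls translate one-to-one:

```python
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()  # auto-initializes if no accelerator was set

# torch.cuda.current_device()        -> accelerator.current_device()
# torch.cuda.synchronize()           -> accelerator.synchronize()
# torch.cuda.get_device_name(device) -> accelerator.get_device_name(device)
index = accelerator.current_device()
accelerator.synchronize()
print(accelerator.get_device_name(index))
```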
13 changes: 13 additions & 0 deletions colossalai/accelerator/__init__.py
@@ -0,0 +1,13 @@
from .api import auto_set_accelerator, get_accelerator, set_accelerator
from .base_accelerator import BaseAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator

__all__ = [
"get_accelerator",
"set_accelerator",
"auto_set_accelerator",
"BaseAccelerator",
"CudaAccelerator",
"NpuAccelerator",
]
72 changes: 72 additions & 0 deletions colossalai/accelerator/api.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python
from collections import OrderedDict
from typing import Union

from .base_accelerator import BaseAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator

__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]


_ACCELERATOR = None


# we use ordered dictionary here to associate the
# order with device check priority
# i.e. auto_set_accelerator will check cuda first
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)


def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
"""
Set the global accelerator for the current process.

Args:
accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs.
"""

global _ACCELERATOR

if isinstance(accelerator, str):
_ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]()
elif isinstance(accelerator, BaseAccelerator):
_ACCELERATOR = accelerator
else:
raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")


def auto_set_accelerator() -> None:
"""
Automatically check if any accelerator is available.
If an accelerator is available, set it as the global accelerator.
"""
global _ACCELERATOR

for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
try:
accelerator = accelerator_cls()
if accelerator.is_available():
_ACCELERATOR = accelerator
break
except Exception:
# the backend package (e.g. torch_npu) may be missing; skip and try the next accelerator
pass

if _ACCELERATOR is None:
raise RuntimeError(
f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
)


def get_accelerator() -> BaseAccelerator:
"""
Return the accelerator for the current process. If no accelerator has been set, it will be initialized
automatically via `auto_set_accelerator()`.

Returns: the accelerator for the current process.
"""
global _ACCELERATOR

if _ACCELERATOR is None:
auto_set_accelerator()
return _ACCELERATOR
81 changes: 81 additions & 0 deletions colossalai/accelerator/base_accelerator.py
@@ -0,0 +1,81 @@
#!/usr/bin/env python
from abc import ABC, abstractmethod
from typing import Union

import torch

__all__ = ["BaseAccelerator"]


class BaseAccelerator(ABC):
def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
self._name = name
self._communication_backend = communication_backend
self._is_synchronous = is_synchronous

# =======================
# immutable attributes
# =======================

@property
def name(self) -> str:
"""
Return the name of the accelerator.
"""
return self._name

@property
def communication_backend(self) -> str:
"""
Return the name of the backend communication library.
"""
return self._communication_backend

@property
def is_synchronous(self) -> bool:
"""
Return whether the accelerator is a synchronous device.
"""
return self._is_synchronous

def __repr__(self) -> str:
cls_name = self.__class__.__name__
return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"

# =======================
# device APIs
# =======================
@abstractmethod
def current_device(self) -> int:
"""
Return the current device index.
"""

@abstractmethod
def set_device(self, device: Union[torch.device, int]) -> None:
"""
Bind the current process to a device.
"""

@abstractmethod
def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""

@abstractmethod
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""

@abstractmethod
def is_available(self):
"""
Check if the accelerator is available.
"""

@abstractmethod
def device_count(self):
"""
Return the number of devices on the machine.
"""
56 changes: 56 additions & 0 deletions colossalai/accelerator/cuda_accelerator.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python
from typing import Union

import torch

from .base_accelerator import BaseAccelerator

__all__ = ["CudaAccelerator"]


class CudaAccelerator(BaseAccelerator):
"""
Accelerator class for Nvidia CUDA devices.
"""

def __init__(self):
super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)

# =======================
# device APIs
# =======================
def current_device(self) -> int:
"""
Return the current device index.
"""
return torch.cuda.current_device()

def set_device(self, device: Union[torch.device, int]) -> None:
"""
Bind the current process to a device.
"""
torch.cuda.set_device(device)

def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
return torch.cuda.get_device_name(device)

def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
torch.cuda.synchronize(device)

def is_available(self):
"""
Check if the accelerator is available.
"""
return torch.cuda.is_available()

def device_count(self):
"""
Return the number of devices on the machine.
"""
return torch.cuda.device_count()
63 changes: 63 additions & 0 deletions colossalai/accelerator/npu_accelerator.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python

from typing import Union

import torch

from .base_accelerator import BaseAccelerator

try:
import torch_npu # noqa
except ImportError:
pass


__all__ = ["NpuAccelerator"]


class NpuAccelerator(BaseAccelerator):
"""
Accelerator class for Huawei NPU devices.
"""

def __init__(self):
super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)

# =======================
# device APIs
# =======================
def current_device(self) -> int:
"""
Return the current device index.
"""
return torch.npu.current_device()

def set_device(self, device: Union[torch.device, int]) -> None:
"""
Bind the current process to a device.
"""
torch.npu.set_device(device)

def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
return torch.npu.get_device_name(device)

def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
torch.npu.synchronize(device)

def is_available(self):
"""
Check if the accelerator is available.
"""
return torch.npu.is_available()

def device_count(self):
"""
Return the number of devices on the machine.
"""
return torch.npu.device_count()