From f23d23245e39f3de2f3c2fbca603073204686a8c Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 28 Nov 2023 17:47:53 +0800 Subject: [PATCH 1/5] [accelerator] init the accelerator module --- colossalai/__init__.py | 1 + colossalai/accelerator/README.md | 20 +++++ colossalai/accelerator/__init__.py | 12 +++ colossalai/accelerator/api.py | 51 ++++++++++++ colossalai/accelerator/base_accelerator.py | 93 ++++++++++++++++++++++ colossalai/accelerator/cuda_accelerator.py | 55 +++++++++++++ colossalai/accelerator/npu_accelerator.py | 61 ++++++++++++++ 7 files changed, 293 insertions(+) create mode 100644 colossalai/accelerator/README.md create mode 100644 colossalai/accelerator/__init__.py create mode 100644 colossalai/accelerator/api.py create mode 100644 colossalai/accelerator/base_accelerator.py create mode 100644 colossalai/accelerator/cuda_accelerator.py create mode 100644 colossalai/accelerator/npu_accelerator.py diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 7da55590305b..6b7f5d055207 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,4 +1,5 @@ from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch +from . import accelerator try: # .version will be created by setup.py diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md new file mode 100644 index 000000000000..783ee8d7b630 --- /dev/null +++ b/colossalai/accelerator/README.md @@ -0,0 +1,20 @@ +# 🚀 Accelerator + +## 🔗 Table of Contents + +- [🚀 Accelerator](#-accelerator) + - [🔗 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [📌 Design and Acknowledgement](#-design-and-acknowledgement) + +## 📚 Introduction + +This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. 
This module is an attempt to make users' code portable across different hardware platform with a simple `set_accelerator()` API. + +## 📌 Design and Acknowledgement + +Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work. + +At the same time, we have implemented our own modifications:1 +1. we updated the accelerator API names to be aligned with PyTorch's native API names. +2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. \ No newline at end of file diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py new file mode 100644 index 000000000000..b870ff5094ed --- /dev/null +++ b/colossalai/accelerator/__init__.py @@ -0,0 +1,12 @@ +from .api import get_accelerator, set_accelerator +from .base_accelerator import BaseAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = [ + 'get_accelerator', + 'set_accelerator', + 'BaseAccelerator', + 'CudaAccelerator', + 'NpuAccelerator' +] \ No newline at end of file diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py new file mode 100644 index 000000000000..64c398d043c3 --- /dev/null +++ b/colossalai/accelerator/api.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +from typing import Union +from .base_accelerator import BaseAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = ['set_accelerator', 'get_accelerator'] + + +_ACCELERATOR = None + +_ACCELERATOR_MAPPING = { + "cuda": CudaAccelerator, + "npu": NpuAccelerator, 
+} + +_DEFAULT_ACCELERATOR_TYPE = CudaAccelerator + +def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: + """ + Set the global accelerator for the current process. + + Args: + accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs. + """ + + global _ACCELERATOR + + if isinstance(accelerator, str): + _ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]() + elif isinstance(accelerator, BaseAccelerator): + _ACCELERATOR = accelerator + else: + raise TypeError("accelerator must be either a string or an instance of BaseAccelerator") + + +def get_accelerator() -> BaseAccelerator: + """ + Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized + to the default accelerator type. + + Returns: the accelerator for the current process. + """ + global _ACCELERATOR + + if _ACCELERATOR is None: + _ACCELERATOR = _DEFAULT_ACCELERATOR_TYPE() + return _ACCELERATOR + + + diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py new file mode 100644 index 000000000000..74dc17d87a5e --- /dev/null +++ b/colossalai/accelerator/base_accelerator.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +import torch +from typing import Union +from abc import ABC, abstractmethod + + +__all__ = ['BaseAccelerator'] + + +class BaseAccelerator(ABC): + + def __init__(self, + name: str, + communication_backend: str, + is_synchronous: bool) -> None: + self._name = name + self._communication_backend = communication_backend + self._is_synchronous = is_synchronous + + #======================= + # immutable attributes + #======================= + + @property + def name(self) -> str: + """ + Return the name of the accelerator. + """ + return self._name + + @property + def communication_backend(self) -> str: + """ + Return the name of the backend communication library. 
+ """ + return self._communication_backend + + @property + def is_synchronous(self) -> bool: + """ + Return whether the accelerator is a synchronous device. + """ + return self._is_synchronous + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})" + + #======================= + # device APIs + #======================= + @abstractmethod + def current_device(self) -> int: + """ + Return the current device index. + """ + pass + + @abstractmethod + def set_device(self, device: Union[torch.device, int]) -> None: + """ + Bind the current process to a device. + """ + pass + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + pass + + @abstractmethod + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + pass + + @abstractmethod + def is_available(self): + """ + Check if the accelerator is available. + """ + pass + + @abstractmethod + def device_count(self): + """ + Return the number of devices on the machine. + """ + pass + + diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py new file mode 100644 index 000000000000..54aca84c1ee5 --- /dev/null +++ b/colossalai/accelerator/cuda_accelerator.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +import torch +from typing import Union +from .base_accelerator import BaseAccelerator + + +__all__ = ['CudaAccelerator'] + +class CudaAccelerator(BaseAccelerator): + """ + Accelerator class for Nvidia CUDA devices. + """ + + def __init__(self): + super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) + + #======================= + # device APIs + #======================= + def current_device(self) -> int: + """ + Return the current device index. 
+ """ + return torch.cuda.current_device() + + def set_device(self, device: Union[torch.device, int]) -> None: + """ + Bind the current process to a device. + """ + torch.cuda.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.cuda.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.cuda.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.cuda.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.cuda.device_count() + diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py new file mode 100644 index 000000000000..ccb31766147b --- /dev/null +++ b/colossalai/accelerator/npu_accelerator.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +import torch +from typing import Union +from .base_accelerator import BaseAccelerator + +try: + import torch_npu +except ImportError: + pass + + +__all__ = ['NpuAccelerator'] + +class NpuAccelerator(BaseAccelerator): + """ + Accelerator class for Huawei NPU devices. + """ + + def __init__(self): + super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) + + #======================= + # device APIs + #======================= + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.npu.current_device() + + def set_device(self, device: Union[torch.device, int]) -> None: + """ + Bind the current process to a device. + """ + torch.npu.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.npu.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. 
+ """ + torch.npu.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.npu.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.npu.device_count() + From aa2117b816625d1ed2336d6099a813d3651024ba Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 28 Nov 2023 17:57:56 +0800 Subject: [PATCH 2/5] polish code --- colossalai/accelerator/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md index 783ee8d7b630..7a43757d72e9 100644 --- a/colossalai/accelerator/README.md +++ b/colossalai/accelerator/README.md @@ -15,6 +15,6 @@ This module offers a layer of abstraction for ColossalAI. With this module, the Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work. -At the same time, we have implemented our own modifications:1 +We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications: 1. we updated the accelerator API names to be aligned with PyTorch's native API names. 2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. 
\ No newline at end of file From bf619e91bd255de4cbba7b52bbd0ac323ce37486 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Tue, 28 Nov 2023 18:05:03 +0800 Subject: [PATCH 3/5] polish code --- colossalai/accelerator/README.md | 2 +- colossalai/accelerator/__init__.py | 8 +----- colossalai/accelerator/api.py | 7 +++-- colossalai/accelerator/base_accelerator.py | 30 +++++++--------------- colossalai/accelerator/cuda_accelerator.py | 15 ++++++----- colossalai/accelerator/npu_accelerator.py | 18 +++++++------ 6 files changed, 32 insertions(+), 48 deletions(-) diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md index 7a43757d72e9..be44fc389eb5 100644 --- a/colossalai/accelerator/README.md +++ b/colossalai/accelerator/README.md @@ -17,4 +17,4 @@ Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https: We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications: 1. we updated the accelerator API names to be aligned with PyTorch's native API names. -2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. \ No newline at end of file +2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. 
diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py index b870ff5094ed..352458fe85d4 100644 --- a/colossalai/accelerator/__init__.py +++ b/colossalai/accelerator/__init__.py @@ -3,10 +3,4 @@ from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator -__all__ = [ - 'get_accelerator', - 'set_accelerator', - 'BaseAccelerator', - 'CudaAccelerator', - 'NpuAccelerator' -] \ No newline at end of file +__all__ = ["get_accelerator", "set_accelerator", "BaseAccelerator", "CudaAccelerator", "NpuAccelerator"] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py index 64c398d043c3..ae9bdba36923 100644 --- a/colossalai/accelerator/api.py +++ b/colossalai/accelerator/api.py @@ -1,10 +1,11 @@ #!/usr/bin/env python from typing import Union + from .base_accelerator import BaseAccelerator from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator -__all__ = ['set_accelerator', 'get_accelerator'] +__all__ = ["set_accelerator", "get_accelerator"] _ACCELERATOR = None @@ -16,6 +17,7 @@ _DEFAULT_ACCELERATOR_TYPE = CudaAccelerator + def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: """ Set the global accelerator for the current process. 
@@ -46,6 +48,3 @@ def get_accelerator() -> BaseAccelerator: if _ACCELERATOR is None: _ACCELERATOR = _DEFAULT_ACCELERATOR_TYPE() return _ACCELERATOR - - - diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py index 74dc17d87a5e..71d03b8d62e8 100644 --- a/colossalai/accelerator/base_accelerator.py +++ b/colossalai/accelerator/base_accelerator.py @@ -1,25 +1,21 @@ #!/usr/bin/env python -import torch -from typing import Union from abc import ABC, abstractmethod +from typing import Union +import torch -__all__ = ['BaseAccelerator'] +__all__ = ["BaseAccelerator"] class BaseAccelerator(ABC): - - def __init__(self, - name: str, - communication_backend: str, - is_synchronous: bool) -> None: + def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None: self._name = name self._communication_backend = communication_backend self._is_synchronous = is_synchronous - #======================= + # ======================= # immutable attributes - #======================= + # ======================= @property def name(self) -> str: @@ -34,7 +30,7 @@ def communication_backend(self) -> str: Return the name of the backend communication library. """ return self._communication_backend - + @property def is_synchronous(self) -> bool: """ @@ -46,48 +42,40 @@ def __repr__(self) -> str: cls_name = self.__class__.__name__ return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})" - #======================= + # ======================= # device APIs - #======================= + # ======================= @abstractmethod def current_device(self) -> int: """ Return the current device index. """ - pass @abstractmethod def set_device(self, device: Union[torch.device, int]) -> None: """ Bind the current process to a device. """ - pass def get_device_name(self, device: Union[torch.device, int]) -> str: """ Return the name of the device. 
""" - pass @abstractmethod def synchronize(self, device: Union[torch.device, int] = None): """ Synchronize the current process. """ - pass @abstractmethod def is_available(self): """ Check if the accelerator is available. """ - pass @abstractmethod def device_count(self): """ Return the number of devices on the machine. """ - pass - - diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py index 54aca84c1ee5..72152834ab33 100644 --- a/colossalai/accelerator/cuda_accelerator.py +++ b/colossalai/accelerator/cuda_accelerator.py @@ -1,10 +1,12 @@ #!/usr/bin/env python -import torch from typing import Union + +import torch + from .base_accelerator import BaseAccelerator +__all__ = ["CudaAccelerator"] -__all__ = ['CudaAccelerator'] class CudaAccelerator(BaseAccelerator): """ @@ -12,11 +14,11 @@ class CudaAccelerator(BaseAccelerator): """ def __init__(self): - super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) - - #======================= + super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) + + # ======================= # device APIs - #======================= + # ======================= def current_device(self) -> int: """ Return the current device index. @@ -52,4 +54,3 @@ def device_count(self): Return the number of devices on the machine. 
""" return torch.cuda.device_count() - diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py index ccb31766147b..01bbb0a3c550 100644 --- a/colossalai/accelerator/npu_accelerator.py +++ b/colossalai/accelerator/npu_accelerator.py @@ -1,16 +1,19 @@ #!/usr/bin/env python -import torch from typing import Union + +import torch + from .base_accelerator import BaseAccelerator try: - import torch_npu + pass except ImportError: pass -__all__ = ['NpuAccelerator'] +__all__ = ["NpuAccelerator"] + class NpuAccelerator(BaseAccelerator): """ @@ -18,11 +21,11 @@ class NpuAccelerator(BaseAccelerator): """ def __init__(self): - super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) - - #======================= + super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) + + # ======================= # device APIs - #======================= + # ======================= def current_device(self) -> int: """ Return the current device index. @@ -58,4 +61,3 @@ def device_count(self): Return the number of devices on the machine. 
""" return torch.npu.device_count() - From 9e8d796b3772d3d9518ea9977f21985bfaca5cf8 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Wed, 29 Nov 2023 10:38:17 +0800 Subject: [PATCH 4/5] polish code --- colossalai/accelerator/__init__.py | 11 +++++-- colossalai/accelerator/api.py | 35 +++++++++++++++++++---- colossalai/accelerator/npu_accelerator.py | 2 +- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py index 352458fe85d4..d144235d3b50 100644 --- a/colossalai/accelerator/__init__.py +++ b/colossalai/accelerator/__init__.py @@ -1,6 +1,13 @@ -from .api import get_accelerator, set_accelerator +from .api import auto_set_accelerator, get_accelerator, set_accelerator from .base_accelerator import BaseAccelerator from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator -__all__ = ["get_accelerator", "set_accelerator", "BaseAccelerator", "CudaAccelerator", "NpuAccelerator"] +__all__ = [ + "get_accelerator", + "set_accelerator", + "auto_set_accelerator", + "BaseAccelerator", + "CudaAccelerator", + "NpuAccelerator", +] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py index ae9bdba36923..96ba98e9d5dd 100644 --- a/colossalai/accelerator/api.py +++ b/colossalai/accelerator/api.py @@ -1,19 +1,21 @@ #!/usr/bin/env python +from collections import OrderedDict from typing import Union from .base_accelerator import BaseAccelerator from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator -__all__ = ["set_accelerator", "get_accelerator"] +__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"] _ACCELERATOR = None -_ACCELERATOR_MAPPING = { - "cuda": CudaAccelerator, - "npu": NpuAccelerator, -} + +# we use ordered dictionary here to associate the +# order with device check priority +# i.e. 
auto_set_accelerator will check cuda first +_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator) _DEFAULT_ACCELERATOR_TYPE = CudaAccelerator @@ -36,6 +38,29 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: raise TypeError("accelerator must be either a string or an instance of BaseAccelerator") +def auto_set_accelerator() -> None: + """ + Automatically check if any accelerator is available. + If an accelerator is available, set it as the global accelerator. + """ + global _ACCELERATOR + + for _, accelerator_cls in _ACCELERATOR_MAPPING.items(): + try: + accelerator = accelerator_cls() + + if accelerator.is_available(): + _ACCELERATOR = accelerator + break + except RuntimeError: + pass + + if _ACCELERATOR is None: + raise RuntimeError( + f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}" + ) + + def get_accelerator() -> BaseAccelerator: """ Return the accelerator for the current process. 
If the accelerator is not initialized, it will be initialized diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py index 01bbb0a3c550..a8bba6eaf8a7 100644 --- a/colossalai/accelerator/npu_accelerator.py +++ b/colossalai/accelerator/npu_accelerator.py @@ -7,7 +7,7 @@ from .base_accelerator import BaseAccelerator try: - pass + import torch_npu # noqa except ImportError: pass From d4229316c894c109a3b77a259b321425ccef1e63 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Wed, 29 Nov 2023 10:50:24 +0800 Subject: [PATCH 5/5] polish code --- colossalai/accelerator/README.md | 2 +- colossalai/accelerator/api.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md index be44fc389eb5..8c644493b03a 100644 --- a/colossalai/accelerator/README.md +++ b/colossalai/accelerator/README.md @@ -9,7 +9,7 @@ ## 📚 Introduction -This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `set_accelerator()` API. +This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API. ## 📌 Design and Acknowledgement diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py index 96ba98e9d5dd..393340b71f34 100644 --- a/colossalai/accelerator/api.py +++ b/colossalai/accelerator/api.py @@ -17,8 +17,6 @@ # i.e. 
auto_set_accelerator will check cuda first _ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator) -_DEFAULT_ACCELERATOR_TYPE = CudaAccelerator - def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: """ @@ -48,11 +46,10 @@ def auto_set_accelerator() -> None: for _, accelerator_cls in _ACCELERATOR_MAPPING.items(): try: accelerator = accelerator_cls() - if accelerator.is_available(): _ACCELERATOR = accelerator break - except RuntimeError: + except: pass if _ACCELERATOR is None: @@ -71,5 +68,5 @@ def get_accelerator() -> BaseAccelerator: global _ACCELERATOR if _ACCELERATOR is None: - _ACCELERATOR = _DEFAULT_ACCELERATOR_TYPE() + auto_set_accelerator() return _ACCELERATOR