From 9c667fb24436736400f76c07b983381164ed87f2 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 6 Dec 2023 16:45:22 +0800 Subject: [PATCH 01/21] update extension --- colossalai/kernel/base_kernel_loader.py | 48 ++++ colossalai/kernel/cpu_adam_loader.py | 22 ++ colossalai/kernel/extensions/__init__.py | 0 .../kernel/extensions/base_extension.py | 34 +++ .../kernel/extensions/cpu_adam/__init__.py | 4 + .../extensions/cpu_adam/arm_extension.py | 53 ++++ .../extensions/cpu_adam/x86_extension.py | 65 +++++ .../kernel/extensions/extension_builder.py | 241 ++++++++++++++++++ colossalai/kernel/extensions/utils.py | 229 +++++++++++++++++ 9 files changed, 696 insertions(+) create mode 100644 colossalai/kernel/base_kernel_loader.py create mode 100644 colossalai/kernel/cpu_adam_loader.py create mode 100644 colossalai/kernel/extensions/__init__.py create mode 100644 colossalai/kernel/extensions/base_extension.py create mode 100644 colossalai/kernel/extensions/cpu_adam/__init__.py create mode 100644 colossalai/kernel/extensions/cpu_adam/arm_extension.py create mode 100644 colossalai/kernel/extensions/cpu_adam/x86_extension.py create mode 100644 colossalai/kernel/extensions/extension_builder.py create mode 100644 colossalai/kernel/extensions/utils.py diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py new file mode 100644 index 000000000000..96574b5ade4a --- /dev/null +++ b/colossalai/kernel/base_kernel_loader.py @@ -0,0 +1,48 @@ +import platform +from abc import ABC, abstractmethod +from typing import Dict, List + +import torch + +from .extensions.base_extension import BaseExtension + + +class BaseKernelLoader(ABC): + """ + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]): + self._extension_map = extension_map + self._supported_device = supported_device + + def run_checks(self): + # run supported device check and other possible checks + pass + + @abstractmethod + def fetch_kernel(self): + pass + + def load(self): + self.run_checks() + return self.fetch_kernel() + + def _is_x86(self) -> bool: + return platform.processor() == "x86_64" + + def _is_arm(self) -> bool: + return platform.processor() == "aarch64" + + def _is_cuda(self) -> bool: + return torch.cuda.is_available() + + def _is_npu(self) -> bool: + try: + import torch_npu # noqa + + return torch.npu.is_available() + except: + return False diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py new file mode 100644 index 000000000000..3998752d411d --- /dev/null +++ b/colossalai/kernel/cpu_adam_loader.py @@ -0,0 +1,22 @@ +from .base_kernel_loader import BaseKernelLoader +from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension + + +class CPUAdamLoader(BaseKernelLoader): + def __init__(self): + super().__init__( + extension_map=dict( + arm=ArmCPUAdamExtension, + x86=X86CPUAdamExtension, + ), + supported_device=["cpu"], + ) + + def fetch_kernel(self): + if self._is_x86(): + kernel = self._extension_map["x86"].fetch() + elif self._is_arm(): + kernel = self._extension_map["arm"].fetch() + else: + raise Exception("not supported") + return kernel diff --git a/colossalai/kernel/extensions/__init__.py b/colossalai/kernel/extensions/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py new file mode 100644 index 
000000000000..cd7398959b21 --- /dev/null +++ b/colossalai/kernel/extensions/base_extension.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Callable + + +class BaseExtension(ABC): + @abstractmethod + @property + def build_completed(self) -> bool: + pass + + @abstractmethod + def build(self) -> None: + pass + + @abstractmethod + def load(self) -> Callable: + pass + + def fetch(self) -> Callable: + if not self.build_completed: + self.build() + return self.load() + + +class CUDAExtension(BaseExtension): + pass + + +class TritonExtension(BaseExtension): + pass + + +class NPUExtension(BaseExtension): + pass diff --git a/colossalai/kernel/extensions/cpu_adam/__init__.py b/colossalai/kernel/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000000..b14f3a978b0f --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/__init__.py @@ -0,0 +1,4 @@ +from .arm_extension import ArmCPUAdamExtension +from .x86_extension import X86CPUAdamExtension + +__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"] diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py new file mode 100644 index 000000000000..d76755f7fbd5 --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/arm_extension.py @@ -0,0 +1,53 @@ +from ..base_extension import BaseExtension +from ..extension_builder import ExtensionBuilder + + +class ArmCPUAdamExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self.kernel_builder = ArmCPUAdamBuilder() + self._is_build_completed = False + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + self.kernel_builder.build() + self._is_build_completed = True + + def load(self): + return self.kernel_builder.load() + + +class ArmCPUAdamBuilder(ExtensionBuilder): + NAME = "arm_cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" + ext_type = "cpu" + + def __init__(self): + super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes")] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py b/colossalai/kernel/extensions/cpu_adam/x86_extension.py new file mode 100644 index 000000000000..024cb633147a --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/x86_extension.py @@ -0,0 +1,65 @@ +from ..base_extension import BaseExtension +from ..extension_builder import ExtensionBuilder +from ..utils import append_nvcc_threads + + +class X86CPUAdamExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self.kernel_builder = X86CPUAdamBuilder() + self._is_build_completed = False + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + self.kernel_builder.build() + self._is_build_completed = True + + def load(self): + return self.kernel_builder.load() + + +class X86CPUAdamBuilder(ExtensionBuilder): + NAME = "cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" + + def __init__(self): + 
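+ # note: ExtensionBuilder first tries the prebuilt colossalai._C.cpu_adam op and only JIT-compiles the C++ source as a fallback (see ExtensionBuilder.build)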
super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("cpu_adam.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-lcudart", + "-lcublas", + "-g", + "-Wno-reorder", + "-fopenmp", + "-march=native", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/extension_builder.py b/colossalai/kernel/extensions/extension_builder.py new file mode 100644 index 000000000000..cb1a0faeba87 --- /dev/null +++ b/colossalai/kernel/extensions/extension_builder.py @@ -0,0 +1,241 @@ +# This code has been adapted from the DeepSpeed library. +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +import importlib +import os +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Optional, Union + +from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 + + +class ExtensionBuilder(ABC): + """ + Builder is the base class to build extensions for PyTorch. + + Args: + name (str): the name of the kernel to be built + prebuilt_import_path (str): the path where the extension is installed during pip install + """ + + ext_type: str = "cuda" + + def __init__(self, name: str, prebuilt_import_path: str): + self.name = name + self.prebuilt_import_path = prebuilt_import_path + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op_module = None + + assert prebuilt_import_path.startswith( + "colossalai._C" + ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + op_builder_module_path = Path(__file__).parent + + # if we install from source + # the current file path will be op_builder/builder.py + # if we install via pip install colossalai + # the current file path will be colossalai/kernel/op_builder/builder.py + # this is because that the op_builder inside colossalai is a symlink + # this symlink will be replaced with actual files if we install via pypi + # thus we cannot tell the colossalai root directory by checking whether the op_builder + # is a symlink, we can only tell whether it is inside or outside colossalai + if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): + root_path = op_builder_module_path.parent.parent + else: + root_path = op_builder_module_path.parent.joinpath("colossalai") + + code_abs_path = root_path.joinpath(code_path) + return str(code_abs_path) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. 
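+ Raises a RuntimeError if torch.utils.cpp_extension.CUDA_HOME is None.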
+ """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + raise NotImplementedError + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def check_runtime_build_environment(self): + """ + Check whether the system environment is ready for extension compilation. + """ + try: + from torch.utils.cpp_extension import CUDA_HOME + + TORCH_AVAILABLE = True + except ImportError: + TORCH_AVAILABLE = False + CUDA_HOME = None + + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" + ) + + if CUDA_HOME is None: + raise RuntimeError( + "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" + ) + + # make sure CUDA is available for compilation during + cuda_available = check_cuda_availability() + if not cuda_available: + raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") + + # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not + check_system_pytorch_cuda_match(CUDA_HOME) + + def build(self, verbose: Optional[bool] = None): + """ + If the kernel is not built during pip install, it will build the kernel. + If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the + kernel is built during pip install, it can be accessed through `colossalai._C`. + + Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + if verbose is None: + verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" + try: + # if the kernel has been pre-built during installation + # we just directly import it + op_module = self.import_op() + if verbose: + print_rank_0( + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." 
+ ) + except ImportError: + # check environment + if self.ext_type == "cuda": + self.check_runtime_build_environment() + + # time the kernel compilation + start_build = time.time() + + # construct the build directory + import torch + from torch.utils.cpp_extension import load + + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + torch_cuda_version = torch.version.cuda + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" + build_directory = os.path.join(home_directory, extension_directory) + Path(build_directory).mkdir(parents=True, exist_ok=True) + + if verbose: + print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") + + # load the kernel + op_module = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose, + ) + + build_duration = time.time() - start_build + + # log jit compilation time + if verbose: + print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") + + # cache the built/loaded kernel + self.cached_op_module = op_module + + def load(self, verbose: Optional[bool] = None): + """ + load the kernel during runtime. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + # if the kernel has be compiled and cached, we directly use it + assert self.cached_op_module is not None, "Please build the kernel first before loading it." + return self.cached_op_module + + def builder(self) -> Union["CUDAExtension", "CppExtension"]: + """ + get a CUDAExtension instance used for setup.py + """ + from torch.utils.cpp_extension import CppExtension, CUDAExtension + + if self.ext_type == "cpp": + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/colossalai/kernel/extensions/utils.py b/colossalai/kernel/extensions/utils.py new file mode 100644 index 000000000000..3f75f952d57b --- /dev/null +++ b/colossalai/kernel/extensions/utils.py @@ -0,0 +1,229 @@ +import os +import re +import subprocess +import warnings +from typing import List + + +def print_rank_0(message: str) -> None: + """ + Print on only one process to avoid spamming. + """ + try: + import torch.distributed as dist + + if not dist.is_initialized(): + is_main_rank = True + else: + is_main_rank = dist.get_rank() == 0 + except ImportError: + is_main_rank = True + + if is_main_rank: + print(message) + + +def get_cuda_version_in_pytorch() -> List[int]: + """ + This function returns the CUDA version in the PyTorch build. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). 
+ """ + import torch + + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + except: + raise ValueError( + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) + return torch_cuda_major, torch_cuda_minor + + +def get_cuda_bare_metal_version(cuda_dir) -> List[int]: + """ + Get the System CUDA version from nvcc. + + Args: + cuda_dir (str): the directory for CUDA Toolkit. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") + + if cuda_dir is None: + raise ValueError( + f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly." + ) + + # check for nvcc path + if not os.path.exists(nvcc_path): + raise FileNotFoundError( + f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME." + ) + + # parse the nvcc -v output to obtain the system cuda version + try: + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + except: + raise ValueError( + f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}" + ) + + return bare_metal_major, bare_metal_minor + + +def check_system_pytorch_cuda_match(cuda_dir): + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch() + + if bare_metal_major != torch_cuda_major: + raise Exception( + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." + ) + + if bare_metal_minor != torch_cuda_minor: + warnings.warn( + f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " + "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. " + "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions" + ) + return True + + +def get_pytorch_version() -> List[int]: + """ + This functions finds the PyTorch version. + + Returns: + A tuple of integers in the form of (major, minor, patch). + """ + import torch + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) + return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH + + +def check_pytorch_version(min_major_version, min_minor_version) -> bool: + """ + Compare the current PyTorch version with the minium required version. + + Args: + min_major_version (int): the minimum major version of PyTorch required + min_minor_version (int): the minimum minor version of PyTorch required + + Returns: + A boolean value. 
The value is True if the current pytorch version is acceptable and False otherwise. + """ + # get pytorch version + torch_major, torch_minor, _ = get_pytorch_version() + + # if the + if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): + raise RuntimeError( + f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) + + +def check_cuda_availability(): + """ + Check if CUDA is available on the system. + + Returns: + A boolean value. True if CUDA is available and False otherwise. + """ + import torch + + return torch.cuda.is_available() + + +def set_cuda_arch_list(cuda_dir): + """ + This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. + Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'. + """ + cuda_available = check_cuda_availability() + + # we only need to set this when CUDA is not available for cross-compilation + if not cuda_available: + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) + + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] + + if int(bare_metal_major) == 11: + if int(bare_metal_minor) == 0: + arch_list.append("8.0") + else: + arch_list.append("8.0") + arch_list.append("8.6") + + arch_list_str = ";".join(arch_list) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str + return False + return True + + +def get_cuda_cc_flag() -> List[str]: + """ + This function produces the cc flags for your GPU arch + + Returns: + The CUDA cc flags for compilation. + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) + for arch in torch.cuda.get_arch_list(): + res = re.search(r"sm_(\d+)", arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) + return cc_flag + + +def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: + """ + This function appends the threads flag to your nvcc args. + + Returns: + The nvcc compilation flags including the threads flag. 
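+ Note: the "--threads 4" flag is appended only when the detected system CUDA version has major >= 11 and minor >= 2.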
+ """ + from torch.utils.cpp_extension import CUDA_HOME + + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args From f903fe6547499f4c2e2e6cad9763d63b5b5f152e Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 6 Dec 2023 18:02:38 +0800 Subject: [PATCH 02/21] update cpu adam --- colossalai/kernel/__init__.py | 2 ++ colossalai/kernel/cpu_adam_loader.py | 4 ++-- colossalai/kernel/extensions/base_extension.py | 1 - colossalai/kernel/extensions/extension_builder.py | 2 ++ colossalai/nn/optimizer/cpu_adam.py | 5 ++--- tests/test_optimizer/test_adam_kernel.py | 4 ++-- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8933fc0a3c2f..6f5b072aa4ab 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,7 +1,9 @@ +from .cpu_adam_loader import CPUAdamLoader from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention __all__ = [ "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention", + "CPUAdamLoader", ] diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 3998752d411d..d525bea668de 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -14,9 +14,9 @@ def __init__(self): def fetch_kernel(self): if self._is_x86(): - kernel = self._extension_map["x86"].fetch() + kernel = self._extension_map["x86"]().fetch() elif self._is_arm(): - kernel = self._extension_map["arm"].fetch() + kernel = self._extension_map["arm"]().fetch() else: raise Exception("not supported") return kernel diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py index cd7398959b21..dd070a238292 100644 --- a/colossalai/kernel/extensions/base_extension.py +++ b/colossalai/kernel/extensions/base_extension.py @@ -4,7 +4,6 @@ class BaseExtension(ABC): @abstractmethod - @property def build_completed(self) -> bool: pass diff --git a/colossalai/kernel/extensions/extension_builder.py b/colossalai/kernel/extensions/extension_builder.py index cb1a0faeba87..5849fcfa6afa 100644 --- a/colossalai/kernel/extensions/extension_builder.py +++ b/colossalai/kernel/extensions/extension_builder.py @@ -51,6 +51,8 @@ def relative_to_abs_path(self, code_path: str) -> str: # is a symlink, we can only tell whether it is inside or outside colossalai if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): root_path = op_builder_module_path.parent.parent + elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"): + root_path = op_builder_module_path.parent.parent else: root_path = op_builder_module_path.parent.joinpath("colossalai") diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 7d53a1dd6834..b2f67cae61d1 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,10 +1,9 @@ import math -import platform from typing import Optional import torch -from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder +from colossalai.kernel import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer @@ -78,7 +77,7 @@ def __init__( default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, 
nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() + cpu_adam = CPUAdamLoader().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6bbe3e4e8172..c136f78a1d60 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -90,9 +90,9 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import CPUAdamBuilder + from colossalai.kernel import CPUAdamLoader - cpu_optim = CPUAdamBuilder().load() + cpu_optim = CPUAdamLoader().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) From 6478272874dc46be4256e7808e3ca77338d18e6c Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 11:50:40 +0800 Subject: [PATCH 03/21] update is --- colossalai/kernel/base_kernel_loader.py | 8 ++++---- colossalai/kernel/cpu_adam_loader.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py index 96574b5ade4a..7977228391a2 100644 --- a/colossalai/kernel/base_kernel_loader.py +++ b/colossalai/kernel/base_kernel_loader.py @@ -30,16 +30,16 @@ def load(self): self.run_checks() return self.fetch_kernel() - def _is_x86(self) -> bool: + def _is_x86_available(self) -> bool: return platform.processor() == "x86_64" - def _is_arm(self) -> bool: + def _is_arm_available(self) -> bool: return platform.processor() == "aarch64" - def _is_cuda(self) -> bool: + def _is_cuda_available(self) -> bool: return torch.cuda.is_available() - def _is_npu(self) -> bool: + def _is_npu_available(self) -> bool: try: import torch_npu # noqa diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 3998752d411d..5ccc049b1b93 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -13,9 +13,9 @@ def __init__(self): ) def fetch_kernel(self): - if self._is_x86(): + if self._is_x86_available(): kernel = self._extension_map["x86"].fetch() - elif self._is_arm(): + elif self._is_arm_available(): kernel = self._extension_map["arm"].fetch() else: raise Exception("not supported") From cc6c21c69bd24fdbcb9119ba8adf9c02613f344c Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 13:55:01 +0800 Subject: [PATCH 04/21] add doc for cpu adam --- colossalai/kernel/cpu_adam_loader.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 5ccc049b1b93..39103de869a4 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -3,6 +3,45 @@ class CPUAdamLoader(BaseKernelLoader): + """ + CPU Adam Loader + + Usage: + # init + cpu_adam = CPUAdamLoader().load() + cpu_adam_op = cpu_adam.CPUAdamOptimizer( + alpha, beta1, beta2, epsilon, weight_decay, adamw_mode + ) + ... 
+ # optim step + cpu_adam_op.step( + step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, + params, grads, exp_avg, exp_avg_sq, loss_scale, + ) + + Args: + CPUAdamOptimizer: + alpha (float): learning rate. Default to 1e-3. + beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9. + beta2 (float): coefficients used for computing running averages of its square. Default to 0.99. + epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8. + weight_decay (float): weight decay (L2 penalty). Default to 0. + adamw_mode (bool): whether to use the adamw. Default to True. + step: + step (int): current step. + lr (float): learning rate. + beta1 (float): coefficients used for computing running averages of gradient. + beta2 (float): coefficients used for computing running averages of its square. + epsilon (float): term added to the denominator to improve numerical stability. + weight_decay (float): weight decay (L2 penalty). + bias_correction (bool): whether to use bias correction. + params (torch.Tensor): parameter. + grads (torch.Tensor): gradient. + exp_avg (torch.Tensor): exp average. + exp_avg_sq (torch.Tensor): exp average square. + loss_scale (float): loss scale value. + """ + def __init__(self): super().__init__( extension_map=dict( From 7f8979a6af5bfc694ec4fa5ec5fa086b117cf2f5 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 14:02:11 +0800 Subject: [PATCH 05/21] update kernel --- colossalai/kernel/cpu_adam_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 39103de869a4..5dd895fcfced 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -10,7 +10,7 @@ class CPUAdamLoader(BaseKernelLoader): # init cpu_adam = CPUAdamLoader().load() cpu_adam_op = cpu_adam.CPUAdamOptimizer( - alpha, beta1, beta2, epsilon, weight_decay, adamw_mode + alpha, beta1, beta2, epsilon, weight_decay, adamw_mode, ) ... # optim step From d0bc49f5c5764d54fefef598987e1b5ac6fe7e75 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 14:44:41 +0800 Subject: [PATCH 06/21] update commit --- colossalai/kernel/extensions/utils.py | 91 +++++++++++++ colossalai/kernel/flash_attention_loader.py | 139 ++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 colossalai/kernel/flash_attention_loader.py diff --git a/colossalai/kernel/extensions/utils.py b/colossalai/kernel/extensions/utils.py index 3f75f952d57b..7e325d47b32e 100644 --- a/colossalai/kernel/extensions/utils.py +++ b/colossalai/kernel/extensions/utils.py @@ -1,3 +1,4 @@ +import enum import os import re import subprocess @@ -227,3 +228,93 @@ def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + + +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.utils.device import get_current_device + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] 
+ assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py new file mode 100644 index 000000000000..38b6e8fcfcca --- /dev/null +++ b/colossalai/kernel/flash_attention_loader.py @@ -0,0 +1,139 @@ +import math +from typing import Optional + +import torch +from einops import rearrange + +from .base_kernel_loader import BaseKernelLoader +from .cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from .cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension +from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad + +if HAS_FLASH_ATTN: + from .cuda_native.mha.flash_attn_2 import flash_attention +if HAS_MEM_EFF_ATTN: + from .cuda_native.mha.mem_eff_attn import mem_eff_attention + + +class FlashAttentionLoader(BaseKernelLoader): + """ + FlashAttention Loader + """ + + def __init__(self): + super().__init__( + extension_map=dict( + arm=ArmCPUAdamExtension, + x86=X86CPUAdamExtension, + ), + supported_device=["cuda", "npu"], + ) + + def fetch_kernel(self, backend: str = None): + if self._is_x86_available(): + kernel = self._extension_map["x86"].fetch() + elif self._is_arm_available(): + kernel = self._extension_map["arm"].fetch() + else: + raise Exception("not supported") + return kernel + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + 
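+ # if no explicit scale is given, default to 1/sqrt(head_dim) computed from embed_dim // num_heads below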
super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: + raise Exception("flash attention can not support!") + + self.attn = None + if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: + self.attn = flash_attention + else: + self.attn = mem_eff_attention + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + origin_attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... 
-> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = self.attn( + query, + key, + value, + seq_len_info_q, + seq_len_info_kv, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + out = rearrange(out, "b s h d -> b s (h d)") + return out From 0b01dd262abbb38351dd9212f800f6c2d0255c1f Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 16:52:50 +0800 Subject: [PATCH 07/21] update flash --- colossalai/kernel/cpu_adam_loader.py | 4 +- .../cuda_flash_attn_2_extension.py | 99 +++++++++++++++++++ .../cuda_memory_efficient_attn_extension.py | 90 +++++++++++++++++ colossalai/kernel/flash_attention_loader.py | 44 ++++----- 4 files changed, 212 insertions(+), 25 deletions(-) create mode 100644 colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py create mode 100644 colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index 5dd895fcfced..f4125bf4a464 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -53,9 +53,9 @@ def __init__(self): def fetch_kernel(self): if self._is_x86_available(): - kernel = self._extension_map["x86"].fetch() + kernel = self._extension_map["x86"]().fetch() elif self._is_arm_available(): - kernel = self._extension_map["arm"].fetch() + kernel = self._extension_map["arm"]().fetch() else: raise Exception("not supported") return kernel diff --git a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py new file mode 100644 index 000000000000..f2cd3bef6c67 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py @@ -0,0 +1,99 @@ +import warnings +from typing import Optional + +import torch + +from ..base_extension import BaseExtension +from ..utils import SeqLenInfo + + +def is_ampere_or_better_gpu(): + if torch.cuda.is_available(): + device = torch.device("cuda") + properties = torch.cuda.get_device_properties(device) + if properties.major >= 8: # Ampere GPUs or newer + return True + return False + + +# "Check Ampere GPUs or newer" +HAS_FLASH_ATTN = False +if is_ampere_or_better_gpu(): + HAS_FLASH_ATTN = True +else: + warnings.warn("FlashAttention only supports Ampere GPUs or newer.") + HAS_FLASH_ATTN = False +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + HAS_FLASH_ATTN = True +except ImportError: + warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") + HAS_FLASH_ATTN = False + +if HAS_FLASH_ATTN: + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. 
Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out + + +class CudaFlashAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self._is_build_completed = True + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + pass + + def is_available(self): + return HAS_FLASH_ATTN + + def load(self): + return flash_attention diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py new file mode 100644 index 000000000000..d62e5148a727 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -0,0 +1,90 @@ +import warnings +from typing import Optional + +import torch + +from ..base_extension import BaseExtension +from ..utils import SeqLenInfo + +HAS_MEM_EFF_ATTN = False +try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + + HAS_MEM_EFF_ATTN = True +except ImportError: + warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") + HAS_MEM_EFF_ATTN = False + +if HAS_MEM_EFF_ATTN: + """ + A general attention module using the flash attention kernels from xformers: + https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha + """ + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." 
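+ # merge the additive (e.g. alibi / relative position) bias into the lower-triangular causal mask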
+ attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out + + +class CudaMemoryEfficentAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self._is_build_completed = True + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + pass + + def is_available(self): + return HAS_MEM_EFF_ATTN + + def load(self): + return mem_eff_attention diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 38b6e8fcfcca..2aa69c5e7b9c 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -5,16 +5,10 @@ from einops import rearrange from .base_kernel_loader import BaseKernelLoader -from .cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from .cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN -from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension +from .extensions.flash_attention.cuda_flash_attn_2_extension import CudaFlashAttnExtension +from .extensions.flash_attention.cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad -if HAS_FLASH_ATTN: - from .cuda_native.mha.flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .cuda_native.mha.mem_eff_attn import mem_eff_attention - class FlashAttentionLoader(BaseKernelLoader): """ @@ -24,18 +18,23 @@ class FlashAttentionLoader(BaseKernelLoader): def __init__(self): super().__init__( extension_map=dict( - arm=ArmCPUAdamExtension, - x86=X86CPUAdamExtension, + cuda_flash_attn=CudaFlashAttnExtension, + cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, ), supported_device=["cuda", "npu"], ) def fetch_kernel(self, backend: str = None): - if self._is_x86_available(): - kernel = self._extension_map["x86"].fetch() - elif self._is_arm_available(): - kernel = self._extension_map["arm"].fetch() - else: + if backend is not None: + return self._extension_map[backend].fetch() + + kernel = None + for _, kernel_extension in self._extension_map.items(): + ext = kernel_extension() + if ext.is_available(): + kernel = ext.fetch() + break + if kernel is None: raise Exception("not supported") return kernel @@ -52,14 +51,7 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=N self.scale = 1 / math.sqrt(embed_dim // num_heads) self.dropout = dropout - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - self.attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - self.attn = flash_attention - else: - self.attn = mem_eff_attention + self.attn = FlashAttentionLoader().fetch_kernel() @staticmethod def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: @@ -79,6 +71,12 @@ def forward( attn_mask_type: Optional[AttnMaskType] = None, bias: Optional[torch.Tensor] = None, ): + # if flash attention is not applicable, switch to memory effcient attention + if self.attn.__name__ == "flash_attention" and ( + query.dtype not in [torch.float16, torch.bfloat16] or bias != None + ): + self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_mem_eff_attn") + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 
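# AttnMaskType values: padding=1, causal=2, paddedcausal=3, so odd values need (un)padding and values > 1 imply a causal mask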
causal = attn_mask_type is not None and attn_mask_type.value > 1 From eebeab38e71e2962f92596cea8dfe1f8a5911abc Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 16:53:27 +0800 Subject: [PATCH 08/21] update memory efficient --- colossalai/kernel/flash_attention_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 2aa69c5e7b9c..27487503599e 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -75,7 +75,7 @@ def forward( if self.attn.__name__ == "flash_attention" and ( query.dtype not in [torch.float16, torch.bfloat16] or bias != None ): - self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_mem_eff_attn") + self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn") padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 causal = attn_mask_type is not None and attn_mask_type.value > 1 From 86064835e38fb27b03689c9760f38ed49d419c31 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 17:00:26 +0800 Subject: [PATCH 09/21] update flash attn --- colossalai/kernel/extensions/flash_attention/__init__.py | 4 ++++ colossalai/kernel/flash_attention_loader.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 colossalai/kernel/extensions/flash_attention/__init__.py diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py new file mode 100644 index 000000000000..c8487601e462 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -0,0 +1,4 @@ +from .cuda_flash_attn_2_extension import CudaFlashAttnExtension +from .cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension + +__all__ = ["CudaFlashAttnExtension", "CudaMemoryEfficentAttnExtension"] diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 27487503599e..c9d229be9bb5 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -5,8 +5,7 @@ from einops import rearrange from .base_kernel_loader import BaseKernelLoader -from .extensions.flash_attention.cuda_flash_attn_2_extension import CudaFlashAttnExtension -from .extensions.flash_attention.cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension +from .extensions.flash_attention import CudaFlashAttnExtension, CudaMemoryEfficentAttnExtension from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad From d2c6e2325c8f251f1f23f3cf5c0160f9d5f91884 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 17:43:36 +0800 Subject: [PATCH 10/21] update flash attention loader --- .../extensions/flash_attention/__init__.py | 9 +- .../cuda_flash_attn_2_extension.py | 22 +-- .../cuda_memory_efficient_attn_extension.py | 6 +- .../npu_sdpa_attn_extension.py | 61 ++++++++ .../npu_triangle_attn_extension.py | 142 ++++++++++++++++++ colossalai/kernel/flash_attention_loader.py | 31 ++-- 6 files changed, 249 insertions(+), 22 deletions(-) create mode 100644 colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py create mode 100644 colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py index c8487601e462..c440622146eb 100644 --- 
a/colossalai/kernel/extensions/flash_attention/__init__.py +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -1,4 +1,11 @@ from .cuda_flash_attn_2_extension import CudaFlashAttnExtension from .cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension +from .npu_sdpa_attn_extension import NpuSpdaAttnExtension +from .npu_triangle_attn_extension import NpuTriangleAttnExtension -__all__ = ["CudaFlashAttnExtension", "CudaMemoryEfficentAttnExtension"] +__all__ = [ + "CudaFlashAttnExtension", + "CudaMemoryEfficentAttnExtension", + "NpuSpdaAttnExtension", + "NpuTriangleAttnExtension", +] diff --git a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py index f2cd3bef6c67..e0dbd544d0a3 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py @@ -8,6 +8,7 @@ def is_ampere_or_better_gpu(): + # Check Ampere GPUs or newer if torch.cuda.is_available(): device = torch.device("cuda") properties = torch.cuda.get_device_properties(device) @@ -16,20 +17,18 @@ def is_ampere_or_better_gpu(): return False -# "Check Ampere GPUs or newer" HAS_FLASH_ATTN = False +ERROR_MSG = None if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + HAS_FLASH_ATTN = True + except ImportError: + ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention" else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer." 
- HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False if HAS_FLASH_ATTN: @@ -39,6 +38,7 @@ def flash_attention( v: torch.Tensor, seq_len_info_q: SeqLenInfo, seq_len_info_kv: SeqLenInfo, + origin_attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: float = None, @@ -93,6 +93,8 @@ def build(self): pass def is_available(self): + if HAS_FLASH_ATTN == False: + warnings.warn(ERROR_MSG) return HAS_FLASH_ATTN def load(self): diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py index d62e5148a727..7be7824f911a 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -18,8 +18,7 @@ HAS_MEM_EFF_ATTN = True except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False + pass if HAS_MEM_EFF_ATTN: """ @@ -37,6 +36,7 @@ def mem_eff_attention( v: torch.Tensor, seq_len_info_q: SeqLenInfo, seq_len_info_kv: SeqLenInfo, + origin_attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: float = None, @@ -84,6 +84,8 @@ def build(self): pass def is_available(self): + if HAS_MEM_EFF_ATTN == False: + warnings.warn("ImportError: please install xformers from https://github.com/facebookresearch/xformers") return HAS_MEM_EFF_ATTN def load(self): diff --git a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py new file mode 100644 index 000000000000..d5af63653b8b --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py @@ -0,0 +1,61 @@ +import torch +from einops import rearrange + +from ..base_extension import BaseExtension + + +def npu_sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, +): + """ + The scaled dot product attention. + + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + scale: float. The scaling of QK^T before applying softmax. + Default to 1. + Return: + attn_out: (batch, q_seqlen, nheads, headdim). 
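+ When origin_attn_mask is None, the underlying scaled_dot_product_attention call is made with is_causal=True.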
+ """ + q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] + output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=origin_attn_mask, + dropout_p=dropout_p, + is_causal=origin_attn_mask is None, + scale=scale, + ) + output = rearrange(output, "b h s d -> b s (h d)") + return output + + +class NpuSpdaAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self._is_build_completed = True + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + pass + + def load(self): + return npu_sdpa_attention diff --git a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py new file mode 100644 index 000000000000..5401b824051d --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import torch +from einops import rearrange + +from ..base_extension import BaseExtension + +HAS_NPU_TRIANGLE_ATTENTION = False +try: + from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax + + HAS_NPU_TRIANGLE_ATTENTION = True +except ImportError: + pass + + +if HAS_NPU_TRIANGLE_ATTENTION: + + def npu_triangle_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, + block_size=512, + ): + """ + The triangle attention reduces the attention calculation of the mask + part by dividing the q, k, and v matrices into blocks + + Arguments: + block_size: The size of the inverted triangle block, the default is 512, + the smaller the block_size, the more calculations will be reduced, + but the number of small operators will be increased + masked_softmax_func: mask function to be applied. + dropout_func: dropout function to be applied. 
+ """ + + def compute_attn(q_layer, k_layer, v_layer, mask_tmp): + # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size] + cur_sim = torch.matmul(q_layer, k_layer) + attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp) + # attention dropout + if dropout_p > 0: + attention_probs = torch.nn.functional.dropout( + attention_probs, p=dropout_p, training=attention_probs.require_grad + ) + # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd] + context_layer_tmp = torch.matmul(attention_probs, v_layer) + return context_layer_tmp + + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)] + origin_attn_mask = origin_attn_mask.to(torch.bool) + # input shape: [b, hn, sq, hd] + bsz, head_num, sequence_len, head_dim = k.shape + sparse_groups = sequence_len // block_size + # Determine whether blocks size can be divided by sequence_length + divisible_flag = sequence_len == block_size * sparse_groups + k = k.transpose(2, 3).contiguous() + if divisible_flag: + q_tmp_layers = torch.chunk(q, sparse_groups, 2) + k_tmp_layers = torch.chunk(k, sparse_groups, 3) + v_tmp_layers = torch.chunk(v, sparse_groups, 2) + else: + seq_tmp = block_size * sparse_groups + q_last = q[:, :, seq_tmp:, :].contiguous() + mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous() + q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2) + k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3) + v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2) + context_list_tmp, k_tmp, v_tmp = [], (), () + for i in range(sparse_groups): + # compute slice shape of q k v for each loop + q_begin, q_end = i * block_size, (i + 1) * block_size + kv_begin, kv_end = 0, (i + 1) * block_size + q_tmp = q_tmp_layers[i] + # slice k and v + if i == 0: + k_tmp = k_tmp_layers[i].contiguous() + v_tmp = v_tmp_layers[i].contiguous() + else: + k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous() + v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous() + + mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous() + context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp) + context_list_tmp.append(context_layer_tmp) + + if not divisible_flag: + # circumstances that cannot be divisible + context_layer_tmp = compute_attn(q_last, k, v, mask_last) + context_list_tmp.append(context_layer_tmp) + context_layer = torch.cat(context_list_tmp, 2) + new_context_layer_shape = (bsz, sequence_len, head_num * head_dim) + context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True) + # ========================= + # Context layer. [b, sq, hp] + # ========================= + return context_layer + + +class NpuTriangleAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self._is_build_completed = True + + @property + def build_completed(self): + return self._is_build_completed + + def build(self): + pass + + def is_available(self): + if HAS_NPU_TRIANGLE_ATTENTION == False: + warnings.warn( + "ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api." 
+ ) + return HAS_NPU_TRIANGLE_ATTENTION + + def load(self): + return npu_triangle_attention diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index c9d229be9bb5..6ae0ae334380 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -5,7 +5,12 @@ from einops import rearrange from .base_kernel_loader import BaseKernelLoader -from .extensions.flash_attention import CudaFlashAttnExtension, CudaMemoryEfficentAttnExtension +from .extensions.flash_attention import ( + CudaFlashAttnExtension, + CudaMemoryEfficentAttnExtension, + NpuSpdaAttnExtension, + NpuTriangleAttnExtension, +) from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad @@ -19,6 +24,8 @@ def __init__(self): extension_map=dict( cuda_flash_attn=CudaFlashAttnExtension, cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, + npu_spda_attn=NpuSpdaAttnExtension, + npu_triangle_attn=NpuTriangleAttnExtension, ), supported_device=["cuda", "npu"], ) @@ -28,13 +35,18 @@ def fetch_kernel(self, backend: str = None): return self._extension_map[backend].fetch() kernel = None - for _, kernel_extension in self._extension_map.items(): - ext = kernel_extension() - if ext.is_available(): - kernel = ext.fetch() - break + if self._is_cuda_available(): + if CudaFlashAttnExtension().is_available(): + kernel = CudaFlashAttnExtension().fetch() + elif CudaMemoryEfficentAttnExtension.is_available(): + kernel = CudaMemoryEfficentAttnExtension().fetch() + elif self._is_npu_available(): + if NpuTriangleAttnExtension().is_available(): + kernel = NpuTriangleAttnExtension().fetch() + else: + kernel = NpuSpdaAttnExtension().fetch() if kernel is None: - raise Exception("not supported") + raise Exception("No extension for flash attention is supported") return kernel @@ -118,8 +130,9 @@ def forward( query, key, value, - seq_len_info_q, - seq_len_info_kv, + seq_len_info_q=seq_len_info_q, + seq_len_info_kv=seq_len_info_kv, + origin_attn_mask=origin_attn_mask, dropout_p=self.dropout, scale=self.scale, causal=causal, From 0cb447d5d939a1bc9fd2d1668705e3fc2ecb1bf4 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 17:57:52 +0800 Subject: [PATCH 11/21] update api --- colossalai/kernel/__init__.py | 5 +++++ colossalai/shardformer/layer/utils.py | 21 ++------------------- colossalai/shardformer/modeling/blip2.py | 2 +- colossalai/shardformer/modeling/chatglm2.py | 2 +- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 14 ++++++++++---- colossalai/shardformer/modeling/opt.py | 2 +- colossalai/shardformer/modeling/vit.py | 2 +- colossalai/shardformer/modeling/whisper.py | 2 +- tests/test_utils/test_flash_attention.py | 3 +-- 10 files changed, 24 insertions(+), 31 deletions(-) diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 6f5b072aa4ab..8a0062f7a8b0 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,9 +1,14 @@ from .cpu_adam_loader import CPUAdamLoader from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention +from .extensions.utils import AttnMaskType +from .flash_attention_loader import ColoAttention, FlashAttentionLoader __all__ = [ "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention", "CPUAdamLoader", + "FlashAttentionLoader", + "ColoAttention", + "AttnMaskType", ] diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 55683b227be9..96fd3bd7bd7b 100644 --- 
a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -6,7 +6,8 @@ from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed + +from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state class SeqParallelUtils: @@ -280,21 +281,3 @@ def create_randomizer_with_offset( Randomizer.increment_index() return Randomizer(seed=base_seed) - - -def get_attention_kernel(): - """ - Get the attention kernel based on the device type. - """ - from colossalai.kernel.cuda_native import AttnMaskType - - if torch.cuda.is_available(): - from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel - else: - try: - torch.npu.is_available() - from colossalai.kernel.npu import NPUColoAttention as AttentionKernel - except: - raise Exception("No available device for attention kernel!") - - return AttnMaskType, AttentionKernel diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 00b2037fbdc8..3522264ad9a0 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward( def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.kernel import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index c8a311df7c6d..0e469b7dd0c4 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f456353742c..9ab51b90ea33 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ def gpt2_for_sequence_classification_forward( def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c3de197c4354..9d02e1376514 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -12,14 +12,15 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer.utils import get_attention_kernel try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + LATEST_VERSION = True except ImportError: LATEST_VERSION = False + class LlamaPipelineForwards: """ This class serves as a micro library for forward 
function substitution of Llama models @@ -405,7 +406,7 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - AttnMaskType, ColoAttention = get_attention_kernel() + from colossalai.kernel import AttnMaskType, ColoAttention llama_version = 2 try: @@ -469,7 +470,12 @@ def forward( attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask, + query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type, + origin_attn_mask=attention_mask, ) attn_output = self.o_proj(attn_output) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f2ca3353bc..625b78bd8775 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ def opt_for_question_answering_forward( def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 5a50e7379cdc..ca3574253e0e 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def pp_forward( def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.kernel import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 9827d4801f8d..f67f6cd63141 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index a5c465ba0b07..ea6cc666ea01 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -9,8 +9,7 @@ from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native import ColoAttention - from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + from colossalai.kernel import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32] From 1a7c9ce453dac059730c5bbdef3185737301f383 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 18:02:18 +0800 Subject: [PATCH 12/21] fix --- colossalai/kernel/flash_attention_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
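Note on this fix: the values stored in _extension_map are extension classes rather than instances, so the chosen backend has to be instantiated before its fetch() can build and load the kernel. A minimal usage sketch of the loader path touched here, assuming the FlashAttentionLoader API added in the previous patch (the backend name below is illustrative; any key of the extension map works):

    from colossalai.kernel import FlashAttentionLoader

    loader = FlashAttentionLoader()
    # instantiate the mapped extension class, then fetch() builds (if needed) and loads the kernel
    flash_attn = loader.fetch_kernel(backend="cuda_flash_attn")
    # or let the loader probe CUDA/NPU availability and choose a backend itself
    attn_kernel = loader.fetch_kernel()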
diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 6ae0ae334380..9822f94fe5b7 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -32,7 +32,7 @@ def __init__(self): def fetch_kernel(self, backend: str = None): if backend is not None: - return self._extension_map[backend].fetch() + return self._extension_map[backend]().fetch() kernel = None if self._is_cuda_available(): From 982474b9c7e80578f1f7490bda887377a0e6073c Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 7 Dec 2023 23:56:47 +0800 Subject: [PATCH 13/21] update doc --- colossalai/kernel/flash_attention_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 9822f94fe5b7..326224b7a7f2 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -17,6 +17,7 @@ class FlashAttentionLoader(BaseKernelLoader): """ FlashAttention Loader + """ def __init__(self): @@ -145,5 +146,6 @@ def forward( out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - out = rearrange(out, "b s h d -> b s (h d)") + if len(out.shape) == 4: + out = rearrange(out, "b s h d -> b s (h d)") return out From 150257339e1727abd62f727f7f1bbbf18d453090 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 8 Dec 2023 00:14:29 +0800 Subject: [PATCH 14/21] update example time limit --- .github/workflows/example_check_on_dispatch.yml | 2 +- .github/workflows/example_check_on_pr.yml | 2 +- .github/workflows/example_check_on_schedule.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 9d3bd9a48235..011a0ae036f2 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -47,7 +47,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 5934704f4102..608ae863fdf1 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -79,7 +79,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 5ed128c3ebc5..4fcd1e3a9ac1 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -35,7 +35,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 From e53408a56347338cee1caf95f005dd30fcf4ebb2 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 8 Dec 2023 00:15:34 +0800 Subject: [PATCH 15/21] reverse change --- 
.github/workflows/example_check_on_dispatch.yml | 2 +- .github/workflows/example_check_on_pr.yml | 2 +- .github/workflows/example_check_on_schedule.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 011a0ae036f2..9d3bd9a48235 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -47,7 +47,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 15 + timeout-minutes: 10 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 608ae863fdf1..5934704f4102 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -79,7 +79,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 15 + timeout-minutes: 10 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 4fcd1e3a9ac1..5ed128c3ebc5 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -35,7 +35,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - timeout-minutes: 15 + timeout-minutes: 10 steps: - name: 📚 Checkout uses: actions/checkout@v3 From 2f77365cceda6dcf9ff2ae823608ca7d6b9f82e5 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 8 Dec 2023 11:21:25 +0800 Subject: [PATCH 16/21] fix doc --- .../extensions/flash_attention/__init__.py | 4 +-- .../npu_sdpa_attn_extension.py | 2 +- colossalai/kernel/flash_attention_loader.py | 32 +++++++++++++++++-- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py index c440622146eb..34fc43392aee 100644 --- a/colossalai/kernel/extensions/flash_attention/__init__.py +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -1,11 +1,11 @@ from .cuda_flash_attn_2_extension import CudaFlashAttnExtension from .cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension -from .npu_sdpa_attn_extension import NpuSpdaAttnExtension +from .npu_sdpa_attn_extension import NpuSdpaAttnExtension from .npu_triangle_attn_extension import NpuTriangleAttnExtension __all__ = [ "CudaFlashAttnExtension", "CudaMemoryEfficentAttnExtension", - "NpuSpdaAttnExtension", + "NpuSdpaAttnExtension", "NpuTriangleAttnExtension", ] diff --git a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py index d5af63653b8b..be94a7b587ae 100644 --- a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py @@ -45,7 +45,7 @@ def npu_sdpa_attention( return output -class NpuSpdaAttnExtension(BaseExtension): +class NpuSdpaAttnExtension(BaseExtension): def __init__(self) -> None: super().__init__() self._is_build_completed = True diff --git 
a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 326224b7a7f2..d9fee3e82b27 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -8,7 +8,7 @@ from .extensions.flash_attention import ( CudaFlashAttnExtension, CudaMemoryEfficentAttnExtension, - NpuSpdaAttnExtension, + NpuSdpaAttnExtension, NpuTriangleAttnExtension, ) from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad @@ -18,6 +18,20 @@ class FlashAttentionLoader(BaseKernelLoader): """ FlashAttention Loader + options: cuda flash attention, cuda memory efficient attention, npu sdpa attention, npu triangle attention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). """ def __init__(self): @@ -25,7 +39,7 @@ def __init__(self): extension_map=dict( cuda_flash_attn=CudaFlashAttnExtension, cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, - npu_spda_attn=NpuSpdaAttnExtension, + npu_sdpa_attn=NpuSdpaAttnExtension, npu_triangle_attn=NpuTriangleAttnExtension, ), supported_device=["cuda", "npu"], @@ -45,7 +59,7 @@ def fetch_kernel(self, backend: str = None): if NpuTriangleAttnExtension().is_available(): kernel = NpuTriangleAttnExtension().fetch() else: - kernel = NpuSpdaAttnExtension().fetch() + kernel = NpuSdpaAttnExtension().fetch() if kernel is None: raise Exception("No extension for flash attention is supported") return kernel @@ -83,6 +97,18 @@ def forward( attn_mask_type: Optional[AttnMaskType] = None, bias: Optional[torch.Tensor] = None, ): + """ + ColoAttention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + origin_attn_mask: (nheads, q_seqlen, kv_seqlen) + bias: will not be used + Return: + attn_out: (batch, q_seqlen, nheads, headdim).
+ """ # if flash attention is not applicable, switch to memory effcient attention if self.attn.__name__ == "flash_attention" and ( query.dtype not in [torch.float16, torch.bfloat16] or bias != None From 13b98b99beb39cc5c26b622b5008f60468736c38 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 8 Dec 2023 12:00:32 +0800 Subject: [PATCH 17/21] remove useless kernel --- colossalai/kernel/cuda_native/__init__.py | 2 - colossalai/kernel/cuda_native/mha/__init__.py | 3 - .../kernel/cuda_native/mha/flash_attn_2.py | 79 ------------ .../kernel/cuda_native/mha/mem_eff_attn.py | 70 ----------- colossalai/kernel/cuda_native/mha/mha.py | 114 ----------------- colossalai/kernel/cuda_native/mha/utils.py | 82 ------------- .../extensions/flash_attention/__init__.py | 9 +- colossalai/kernel/npu/__init__.py | 3 - colossalai/kernel/npu/mha/__init__.py | 3 - colossalai/kernel/npu/mha/mha.py | 80 ------------ colossalai/kernel/npu/mha/sdpa_attn.py | 41 ------- colossalai/kernel/npu/mha/triangle_attn.py | 115 ------------------ .../openmoe/model/modeling_openmoe.py | 2 +- tests/test_utils/test_flash_attention.py | 3 +- 14 files changed, 8 insertions(+), 598 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/mha/__init__.py delete mode 100644 colossalai/kernel/cuda_native/mha/flash_attn_2.py delete mode 100644 colossalai/kernel/cuda_native/mha/mem_eff_attn.py delete mode 100644 colossalai/kernel/cuda_native/mha/mha.py delete mode 100644 colossalai/kernel/cuda_native/mha/utils.py delete mode 100644 colossalai/kernel/npu/__init__.py delete mode 100644 colossalai/kernel/npu/mha/__init__.py delete mode 100644 colossalai/kernel/npu/mha/mha.py delete mode 100644 colossalai/kernel/npu/mha/sdpa_attn.py delete mode 100644 colossalai/kernel/npu/mha/triangle_attn.py diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index f8a974b5fb26..0eac28d23e24 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,5 +1,4 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax @@ -8,6 +7,5 @@ "MultiHeadAttention", "FusedScaleMaskSoftmax", "ScaledUpperTriangMaskedSoftmax", - "ColoAttention", "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py deleted file mode 100644 index cad36e598d14..000000000000 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import ColoAttention - -__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py deleted file mode 100644 index de2ccaa4947f..000000000000 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ /dev/null @@ -1,79 +0,0 @@ -import warnings -from typing import Optional - -import torch - - -def is_ampere_or_better_gpu(): - if torch.cuda.is_available(): - device = torch.device("cuda") - properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer - return True - return False - - -# "Check Ampere GPUs or newer" -HAS_FLASH_ATTN = False -if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True -else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import 
flash_attn_func, flash_attn_varlen_func - - HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False - -if HAS_FLASH_ATTN: - - from .utils import SeqLenInfo - - def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( - q, - k, - v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, - ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py deleted file mode 100644 index 649e74d61bab..000000000000 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings - -HAS_MEM_EFF_ATTN = False -try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - - HAS_MEM_EFF_ATTN = True -except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False - -if HAS_MEM_EFF_ATTN: - """ - A general attention module using the flash attention kernels from xformers: - https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha - """ - from typing import Optional - - import torch - - from .utils import SeqLenInfo - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." 
- attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py deleted file mode 100644 index b56d37cf026e..000000000000 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ /dev/null @@ -1,114 +0,0 @@ -import math -from typing import Optional - -import torch -from einops import rearrange - -from ..scaled_softmax import AttnMaskType -from .flash_attn_2 import HAS_FLASH_ATTN -from .mem_eff_attn import HAS_MEM_EFF_ATTN -from .utils import Repad, SeqLenInfo, Unpad - -if HAS_FLASH_ATTN: - from .flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .mem_eff_attn import mem_eff_attention - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - attn = flash_attention - else: - attn = mem_eff_attention - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... 
-> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = attn( - query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py deleted file mode 100644 index 5f01e3ef327d..000000000000 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.utils.device import get_current_device - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] 
- return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py index 34fc43392aee..3c3a8ee87bc6 100644 --- a/colossalai/kernel/extensions/flash_attention/__init__.py +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -1,11 +1,14 @@ -from .cuda_flash_attn_2_extension import CudaFlashAttnExtension -from .cuda_memory_efficient_attn_extension import CudaMemoryEfficentAttnExtension +from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension +from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension from .npu_sdpa_attn_extension import NpuSdpaAttnExtension -from .npu_triangle_attn_extension import NpuTriangleAttnExtension +from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension __all__ = [ "CudaFlashAttnExtension", "CudaMemoryEfficentAttnExtension", "NpuSdpaAttnExtension", "NpuTriangleAttnExtension", + "HAS_FLASH_ATTN", + "HAS_MEM_EFF_ATTN", + "HAS_NPU_TRIANGLE_ATTENTION", ] diff --git a/colossalai/kernel/npu/__init__.py b/colossalai/kernel/npu/__init__.py deleted file mode 100644 index 6a02c705559a..000000000000 --- a/colossalai/kernel/npu/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import NPUColoAttention - -__all__ = ["NPUColoAttention"] diff --git a/colossalai/kernel/npu/mha/__init__.py b/colossalai/kernel/npu/mha/__init__.py deleted file mode 100644 index 6a02c705559a..000000000000 --- a/colossalai/kernel/npu/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import NPUColoAttention - -__all__ = ["NPUColoAttention"] diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py deleted file mode 100644 index ac982384e518..000000000000 --- a/colossalai/kernel/npu/mha/mha.py +++ /dev/null @@ -1,80 +0,0 @@ -import math -from typing import Optional - -import torch - -from .sdpa_attn import npu_sdpa_attention -from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION - - -class NPUColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None): - super().__init__() - - try: - import torch_npu # noqa - except ImportError: - raise Exception("torch_npu is not installed.") - - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." 
- if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: int = None, - bias: Optional[torch.Tensor] = None, - ): - """ - Implement the scaled dot product attention with softmax. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - assert ( - len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4 - ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}" - assert ( - query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu" - ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}" - assert bias is None, "bias is not supported in npu colo attention" - - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - if HAS_NPU_TRIANGLE_ATTENTION: - from .triangle_attn import npu_triangle_attention - - attn_fn = npu_triangle_attention - else: - attn_fn = npu_sdpa_attention - - out = attn_fn( - query, - key, - value, - attn_mask=attn_mask, - origin_attn_mask=origin_attn_mask, - dropout_p=self.dropout, - scale=self.scale, - is_causal=causal, - ) - return out diff --git a/colossalai/kernel/npu/mha/sdpa_attn.py b/colossalai/kernel/npu/mha/sdpa_attn.py deleted file mode 100644 index 2af1dbae2e67..000000000000 --- a/colossalai/kernel/npu/mha/sdpa_attn.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from einops import rearrange - - -def npu_sdpa_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: torch.Tensor = None, - origin_attn_mask: torch.Tensor = None, - scale: float = 1.0, - dropout_p: float = 0.0, - is_causal: bool = True, -): - """ - The scaled dot product attention. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] - output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=origin_attn_mask, - dropout_p=dropout_p, - is_causal=origin_attn_mask is None, - scale=scale, - ) - output = rearrange(output, "b h s d -> b s (h d)") - return output diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py deleted file mode 100644 index 619076d5f888..000000000000 --- a/colossalai/kernel/npu/mha/triangle_attn.py +++ /dev/null @@ -1,115 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import torch -from einops import rearrange - -HAS_NPU_TRIANGLE_ATTENTION = False -try: - from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax - - HAS_NPU_TRIANGLE_ATTENTION = True -except ImportError: - logging.warning("Import torch_npu Error.") - - -if HAS_NPU_TRIANGLE_ATTENTION: - - def npu_triangle_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: torch.Tensor = None, - origin_attn_mask: torch.Tensor = None, - scale: float = 1.0, - dropout_p: float = 0.0, - is_causal: bool = True, - block_size=512, - ): - """ - The triangle attention reduces the attention calculation of the mask - part by dividing the q, k, and v matrices into blocks - - Arguments: - block_size: The size of the inverted triangle block, the default is 512, - the smaller the block_size, the more calculations will be reduced, - but the number of small operators will be increased - masked_softmax_func: mask function to be applied. - dropout_func: dropout function to be applied. - """ - - def compute_attn(q_layer, k_layer, v_layer, mask_tmp): - # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size] - cur_sim = torch.matmul(q_layer, k_layer) - attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp) - # attention dropout - if dropout_p > 0: - attention_probs = torch.nn.functional.dropout( - attention_probs, p=dropout_p, training=attention_probs.require_grad - ) - # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd] - context_layer_tmp = torch.matmul(attention_probs, v_layer) - return context_layer_tmp - - q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)] - origin_attn_mask = origin_attn_mask.to(torch.bool) - # input shape: [b, hn, sq, hd] - bsz, head_num, sequence_len, head_dim = k.shape - sparse_groups = sequence_len // block_size - # Determine whether blocks size can be divided by sequence_length - divisible_flag = sequence_len == block_size * sparse_groups - k = k.transpose(2, 3).contiguous() - if divisible_flag: - q_tmp_layers = torch.chunk(q, sparse_groups, 2) - k_tmp_layers = torch.chunk(k, sparse_groups, 3) - v_tmp_layers = torch.chunk(v, sparse_groups, 2) - else: - seq_tmp = block_size * sparse_groups - q_last = q[:, :, seq_tmp:, :].contiguous() - mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous() - q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2) - k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3) - v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2) - context_list_tmp, k_tmp, v_tmp = [], (), () - for i in range(sparse_groups): - # compute slice shape of q k v for each loop - q_begin, q_end = i * block_size, (i + 1) * block_size - kv_begin, kv_end = 0, (i + 1) * block_size - q_tmp = q_tmp_layers[i] - # slice k and v - if i == 0: - k_tmp = k_tmp_layers[i].contiguous() - v_tmp = v_tmp_layers[i].contiguous() - else: - k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous() - v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous() - - mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous() - 
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp) - context_list_tmp.append(context_layer_tmp) - - if not divisible_flag: - # circumstances that cannot be divisible - context_layer_tmp = compute_attn(q_last, k, v, mask_last) - context_list_tmp.append(context_layer_tmp) - context_layer = torch.cat(context_list_tmp, 2) - new_context_layer_shape = (bsz, sequence_len, head_num * head_dim) - context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True) - # ========================= - # Context layer. [b, sq, hp] - # ========================= - return context_layer diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317903..eee3b505a22a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,7 @@ replace_return_docstrings, ) -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index ea6cc666ea01..30a30e86acae 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,8 +4,7 @@ import torch from einops import rearrange -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: From b660c5e1a8e7bfc140e62c9398642dc50ba3fffd Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 11 Dec 2023 15:10:35 +0800 Subject: [PATCH 18/21] fix --- colossalai/kernel/__init__.py | 2 +- colossalai/kernel/cpu_adam_loader.py | 4 +- .../extensions/flash_attention/__init__.py | 5 + .../cuda_flash_attn_2_extension.py | 2 +- .../cuda_memory_efficient_attn_extension.py | 2 +- .../extensions/flash_attention/utils.py | 89 ++++++++++++++++++ colossalai/kernel/extensions/utils.py | 91 ------------------- colossalai/kernel/flash_attention_loader.py | 9 +- 8 files changed, 106 insertions(+), 98 deletions(-) create mode 100644 colossalai/kernel/extensions/flash_attention/utils.py diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8a0062f7a8b0..5356fbf48c95 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,6 +1,6 @@ from .cpu_adam_loader import CPUAdamLoader from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention -from .extensions.utils import AttnMaskType +from .extensions.flash_attention import AttnMaskType from .flash_attention_loader import ColoAttention, FlashAttentionLoader __all__ = [ diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index f4125bf4a464..cd2a2257d7c7 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -20,14 +20,14 @@ class CPUAdamLoader(BaseKernelLoader): ) Args: - CPUAdamOptimizer: + func CPUAdamOptimizer: alpha (float): learning rate. Default to 1e-3. beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9. 
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99. epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8. weight_decay (float): weight decay (L2 penalty). Default to 0. adamw_mode (bool): whether to use the adamw. Default to True. - step: + func step: step (int): current step. lr (float): learning rate. beta1 (float): coefficients used for computing running averages of gradient. diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py index 3c3a8ee87bc6..79c6935d2260 100644 --- a/colossalai/kernel/extensions/flash_attention/__init__.py +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -2,6 +2,7 @@ from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension from .npu_sdpa_attn_extension import NpuSdpaAttnExtension from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension +from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad __all__ = [ "CudaFlashAttnExtension", @@ -11,4 +12,8 @@ "HAS_FLASH_ATTN", "HAS_MEM_EFF_ATTN", "HAS_NPU_TRIANGLE_ATTENTION", + "Unpad", + "AttnMaskType", + "Repad", + "SeqLenInfo", ] diff --git a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py index e0dbd544d0a3..5cc4fb7dad02 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py @@ -4,7 +4,7 @@ import torch from ..base_extension import BaseExtension -from ..utils import SeqLenInfo +from .utils import SeqLenInfo def is_ampere_or_better_gpu(): diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py index 7be7824f911a..b9ec3fe945c1 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -4,7 +4,7 @@ import torch from ..base_extension import BaseExtension -from ..utils import SeqLenInfo +from .utils import SeqLenInfo HAS_MEM_EFF_ATTN = False try: diff --git a/colossalai/kernel/extensions/flash_attention/utils.py b/colossalai/kernel/extensions/flash_attention/utils.py new file mode 100644 index 000000000000..0eab9e89f88b --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/utils.py @@ -0,0 +1,89 @@ +import enum +from dataclasses import dataclass +from typing import Iterable, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.utils.device import get_current_device + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] 
+ grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 diff --git a/colossalai/kernel/extensions/utils.py b/colossalai/kernel/extensions/utils.py index 7e325d47b32e..3f75f952d57b 100644 --- a/colossalai/kernel/extensions/utils.py +++ b/colossalai/kernel/extensions/utils.py @@ -1,4 +1,3 @@ -import enum import os import re import subprocess @@ -228,93 +227,3 @@ def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args - - -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.utils.device import get_current_device - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] 
- return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - paddedcausal = 3 diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index d9fee3e82b27..9d292e2dd13c 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -6,12 +6,15 @@ from .base_kernel_loader import BaseKernelLoader from .extensions.flash_attention import ( + AttnMaskType, CudaFlashAttnExtension, CudaMemoryEfficentAttnExtension, NpuSdpaAttnExtension, NpuTriangleAttnExtension, + Repad, + SeqLenInfo, + Unpad, ) -from .extensions.utils import AttnMaskType, Repad, SeqLenInfo, Unpad class FlashAttentionLoader(BaseKernelLoader): @@ -47,13 +50,15 @@ def __init__(self): def fetch_kernel(self, backend: str = None): if backend is not None: + if not self._extension_map[backend]().is_available(): + raise Exception(f"{backend} is not available for flash attention.") return self._extension_map[backend]().fetch() kernel = None if self._is_cuda_available(): if CudaFlashAttnExtension().is_available(): kernel = CudaFlashAttnExtension().fetch() - elif CudaMemoryEfficentAttnExtension.is_available(): + elif CudaMemoryEfficentAttnExtension().is_available(): kernel = CudaMemoryEfficentAttnExtension().fetch() elif self._is_npu_available(): if NpuTriangleAttnExtension().is_available(): From dfff0a0c1fa623d7950d0ed2509e1118d2c47d60 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 11 Dec 2023 15:20:51 +0800 Subject: [PATCH 19/21] not use warning --- .../flash_attention/cuda_memory_efficient_attn_extension.py | 4 ++-- .../extensions/flash_attention/npu_triangle_attn_extension.py | 4 ++-- colossalai/kernel/flash_attention_loader.py | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py index b9ec3fe945c1..f4f5fe7e8a98 100644 --- 
a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -1,9 +1,9 @@ -import warnings from typing import Optional import torch from ..base_extension import BaseExtension +from ..utils import print_rank_0 from .utils import SeqLenInfo HAS_MEM_EFF_ATTN = False @@ -85,7 +85,7 @@ def build(self): def is_available(self): if HAS_MEM_EFF_ATTN == False: - warnings.warn("ImportError: please install xformers from https://github.com/facebookresearch/xformers") + print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers") return HAS_MEM_EFF_ATTN def load(self): diff --git a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py index 5401b824051d..bf70b635e61c 100644 --- a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings import torch from einops import rearrange from ..base_extension import BaseExtension +from ..utils import print_rank_0 HAS_NPU_TRIANGLE_ATTENTION = False try: @@ -133,7 +133,7 @@ def build(self): def is_available(self): if HAS_NPU_TRIANGLE_ATTENTION == False: - warnings.warn( + print_rank_0( "ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api." ) return HAS_NPU_TRIANGLE_ATTENTION diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index 9d292e2dd13c..e35969ba2c3b 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -15,6 +15,7 @@ SeqLenInfo, Unpad, ) +from .extensions.utils import print_rank_0 class FlashAttentionLoader(BaseKernelLoader): @@ -118,6 +119,7 @@ def forward( if self.attn.__name__ == "flash_attention" and ( query.dtype not in [torch.float16, torch.bfloat16] or bias != None ): + print_rank_0("flash attention is not applicable, switch to memory effcient attention") self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn") padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 From bab51b0c8b62548ff46877f2e29256e7f46a97ad Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 11 Dec 2023 16:03:32 +0800 Subject: [PATCH 20/21] update --- colossalai/kernel/base_kernel_loader.py | 25 +++---------------- colossalai/kernel/cpu_adam_loader.py | 9 ++++--- .../kernel/extensions/base_extension.py | 16 ++---------- .../extensions/cpu_adam/arm_extension.py | 8 +++--- .../extensions/cpu_adam/x86_extension.py | 8 +++--- .../cuda_flash_attn_2_extension.py | 9 +++---- .../cuda_memory_efficient_attn_extension.py | 5 ++-- .../npu_sdpa_attn_extension.py | 5 ++-- .../npu_triangle_attn_extension.py | 5 ++-- colossalai/kernel/flash_attention_loader.py | 23 +++++++++-------- 10 files changed, 41 insertions(+), 72 deletions(-) diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py index 7977228391a2..21b553e0ce1a 100644 --- a/colossalai/kernel/base_kernel_loader.py +++ b/colossalai/kernel/base_kernel_loader.py @@ -1,8 +1,6 @@ -import platform from abc import ABC, abstractmethod -from typing import Dict, List - -import torch +from collections 
import OrderedDict +from typing import List from .extensions.base_extension import BaseExtension @@ -14,7 +12,7 @@ class BaseKernelLoader(ABC): kernel = kernel_loader.load() """ - def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]): + def __init__(self, extension_map: OrderedDict[str, BaseExtension], supported_device: List[str]): self._extension_map = extension_map self._supported_device = supported_device @@ -29,20 +27,3 @@ def fetch_kernel(self): def load(self): self.run_checks() return self.fetch_kernel() - - def _is_x86_available(self) -> bool: - return platform.processor() == "x86_64" - - def _is_arm_available(self) -> bool: - return platform.processor() == "aarch64" - - def _is_cuda_available(self) -> bool: - return torch.cuda.is_available() - - def _is_npu_available(self) -> bool: - try: - import torch_npu # noqa - - return torch.npu.is_available() - except: - return False diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py index cd2a2257d7c7..0df6bd49b4c9 100644 --- a/colossalai/kernel/cpu_adam_loader.py +++ b/colossalai/kernel/cpu_adam_loader.py @@ -1,3 +1,6 @@ +import platform +from collections import OrderedDict + from .base_kernel_loader import BaseKernelLoader from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension @@ -44,7 +47,7 @@ class CPUAdamLoader(BaseKernelLoader): def __init__(self): super().__init__( - extension_map=dict( + extension_map=OrderedDict( arm=ArmCPUAdamExtension, x86=X86CPUAdamExtension, ), @@ -52,9 +55,9 @@ def __init__(self): ) def fetch_kernel(self): - if self._is_x86_available(): + if platform.machine() == "x86_64": kernel = self._extension_map["x86"]().fetch() - elif self._is_arm_available(): + elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]: kernel = self._extension_map["arm"]().fetch() else: raise Exception("not supported") diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py index dd070a238292..8905dbf13180 100644 --- a/colossalai/kernel/extensions/base_extension.py +++ b/colossalai/kernel/extensions/base_extension.py @@ -4,7 +4,7 @@ class BaseExtension(ABC): @abstractmethod - def build_completed(self) -> bool: + def requires_build(self) -> bool: pass @abstractmethod @@ -16,18 +16,6 @@ def load(self) -> Callable: pass def fetch(self) -> Callable: - if not self.build_completed: + if self.requires_build: self.build() return self.load() - - -class CUDAExtension(BaseExtension): - pass - - -class TritonExtension(BaseExtension): - pass - - -class NPUExtension(BaseExtension): - pass diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py index d76755f7fbd5..9868059bfcfd 100644 --- a/colossalai/kernel/extensions/cpu_adam/arm_extension.py +++ b/colossalai/kernel/extensions/cpu_adam/arm_extension.py @@ -6,15 +6,15 @@ class ArmCPUAdamExtension(BaseExtension): def __init__(self) -> None: super().__init__() self.kernel_builder = ArmCPUAdamBuilder() - self._is_build_completed = False + self._requires_build = False @property - def build_completed(self): - return self._is_build_completed + def requires_build(self) -> bool: + return self._requires_build def build(self): self.kernel_builder.build() - self._is_build_completed = True + self._requires_build = True def load(self): return self.kernel_builder.load() diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py 
b/colossalai/kernel/extensions/cpu_adam/x86_extension.py index 024cb633147a..687c91f35759 100644 --- a/colossalai/kernel/extensions/cpu_adam/x86_extension.py +++ b/colossalai/kernel/extensions/cpu_adam/x86_extension.py @@ -7,15 +7,15 @@ class X86CPUAdamExtension(BaseExtension): def __init__(self) -> None: super().__init__() self.kernel_builder = X86CPUAdamBuilder() - self._is_build_completed = False + self._requires_build = False @property - def build_completed(self): - return self._is_build_completed + def requires_build(self) -> bool: + return self._requires_build def build(self): self.kernel_builder.build() - self._is_build_completed = True + self._requires_build = True def load(self): return self.kernel_builder.load() diff --git a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py index 5cc4fb7dad02..99c3536063fb 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py @@ -1,9 +1,9 @@ -import warnings from typing import Optional import torch from ..base_extension import BaseExtension +from ..utils import print_rank_0 from .utils import SeqLenInfo @@ -83,18 +83,17 @@ def flash_attention( class CudaFlashAttnExtension(BaseExtension): def __init__(self) -> None: super().__init__() - self._is_build_completed = True @property - def build_completed(self): - return self._is_build_completed + def requires_build(self): + return False def build(self): pass def is_available(self): if HAS_FLASH_ATTN == False: - warnings.warn(ERROR_MSG) + print_rank_0(ERROR_MSG) return HAS_FLASH_ATTN def load(self): diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py index f4f5fe7e8a98..4954ab5b1d28 100644 --- a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -74,11 +74,10 @@ def mem_eff_attention( class CudaMemoryEfficentAttnExtension(BaseExtension): def __init__(self) -> None: super().__init__() - self._is_build_completed = True @property - def build_completed(self): - return self._is_build_completed + def requires_build(self) -> bool: + return False def build(self): pass diff --git a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py index be94a7b587ae..7dc9d9b9b118 100644 --- a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py @@ -48,11 +48,10 @@ def npu_sdpa_attention( class NpuSdpaAttnExtension(BaseExtension): def __init__(self) -> None: super().__init__() - self._is_build_completed = True @property - def build_completed(self): - return self._is_build_completed + def requires_build(self) -> bool: + return False def build(self): pass diff --git a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py index bf70b635e61c..a760f56a195e 100644 --- a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py +++ b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py @@ -122,11 +122,10 @@ def compute_attn(q_layer, 
k_layer, v_layer, mask_tmp): class NpuTriangleAttnExtension(BaseExtension): def __init__(self) -> None: super().__init__() - self._is_build_completed = True @property - def build_completed(self): - return self._is_build_completed + def requires_build(self) -> bool: + return False def build(self): pass diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py index e35969ba2c3b..3d0cd397540a 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/kernel/flash_attention_loader.py @@ -1,9 +1,12 @@ import math +from collections import OrderedDict from typing import Optional import torch from einops import rearrange +from colossalai.accelerator import get_accelerator + from .base_kernel_loader import BaseKernelLoader from .extensions.flash_attention import ( AttnMaskType, @@ -40,7 +43,8 @@ class FlashAttentionLoader(BaseKernelLoader): def __init__(self): super().__init__( - extension_map=dict( + # extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx + extension_map=OrderedDict( cuda_flash_attn=CudaFlashAttnExtension, cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, npu_sdpa_attn=NpuSdpaAttnExtension, @@ -56,16 +60,13 @@ def fetch_kernel(self, backend: str = None): return self._extension_map[backend]().fetch() kernel = None - if self._is_cuda_available(): - if CudaFlashAttnExtension().is_available(): - kernel = CudaFlashAttnExtension().fetch() - elif CudaMemoryEfficentAttnExtension().is_available(): - kernel = CudaMemoryEfficentAttnExtension().fetch() - elif self._is_npu_available(): - if NpuTriangleAttnExtension().is_available(): - kernel = NpuTriangleAttnExtension().fetch() - else: - kernel = NpuSdpaAttnExtension().fetch() + accelerator_name = get_accelerator().name + assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention." + for extension_name, extension in self._extension_map.items(): + if extension_name.startswith(accelerator_name): + if extension().is_available(): + kernel = extension().fetch() + break if kernel is None: raise Exception("No extension for flash attention is supported") return kernel From 1b8b31341cc56db54228081024c43135abfa7966 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 11 Dec 2023 18:14:35 +0800 Subject: [PATCH 21/21] update --- colossalai/kernel/base_kernel_loader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py index 21b553e0ce1a..ff7a4326133a 100644 --- a/colossalai/kernel/base_kernel_loader.py +++ b/colossalai/kernel/base_kernel_loader.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import List +from typing import Dict, List from .extensions.base_extension import BaseExtension @@ -12,7 +11,7 @@ class BaseKernelLoader(ABC): kernel = kernel_loader.load() """ - def __init__(self, extension_map: OrderedDict[str, BaseExtension], supported_device: List[str]): + def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]): self._extension_map = extension_map self._supported_device = supported_device
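
Note for reviewers: a minimal sketch of how the padding helpers that now live in colossalai/kernel/extensions/flash_attention/utils.py fit together. The shapes, tensor values and variable names below are illustrative only and are not part of the patch; the import path follows the exports added to flash_attention/__init__.py in this series.

    import torch
    from colossalai.kernel.extensions.flash_attention import Repad, SeqLenInfo, Unpad

    # toy padded batch: [batch, seq_len, hidden]; in the mask, 1 = real token, 0 = padding
    batch_size, seq_len, hidden = 2, 4, 8
    hidden_states = torch.randn(batch_size, seq_len, hidden)
    attn_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]])

    # indices, max_seqlen and cu_seqlens consumed by the attention kernels
    info = SeqLenInfo.materialize(attn_mask=attn_mask, device=hidden_states.device)

    # drop padding tokens: [batch, seq_len, hidden] -> [ntokens, hidden]
    unpadded = Unpad.apply(hidden_states, info.indices)

    # scatter them back to [batch*seq_len, hidden], then reshape to the padded layout
    repadded = Repad.apply(unpadded, info.indices, batch_size, seq_len)
    repadded = repadded.view(batch_size, seq_len, hidden)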
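
Likewise, a sketch of the loader entry points after this refactor, following the Usage note in BaseKernelLoader. Import paths are taken from the new module layout (colossalai/kernel/cpu_adam_loader.py and colossalai/kernel/flash_attention_loader.py); whether these classes are also re-exported from colossalai.kernel is not shown in this series, and the backend string is one of the keys registered in the loader's extension map.

    from colossalai.kernel.cpu_adam_loader import CPUAdamLoader
    from colossalai.kernel.flash_attention_loader import FlashAttentionLoader

    # CPU Adam: dispatches to the x86 or ARM builder based on platform.machine()
    # and returns the loaded kernel module.
    cpu_adam_kernel = CPUAdamLoader().load()

    # Flash attention: walk the extension map and return the first available
    # extension whose name starts with the current accelerator's name ...
    flash_attn_kernel = FlashAttentionLoader().fetch_kernel()

    # ... or request a specific backend by its registry key.
    flash_attn_cuda = FlashAttentionLoader().fetch_kernel(backend="cuda_flash_attn")

The name-prefix convention called out in the extension_map comment ("extension name must start with the accelerator name") is what lets fetch_kernel() iterate the OrderedDict in registration order and stop at the first extension that reports is_available() for the active accelerator.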