Merged
7 changes: 7 additions & 0 deletions colossalai/kernel/__init__.py
@@ -1,7 +1,14 @@
from .cpu_adam_loader import CPUAdamLoader
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
from .extensions.flash_attention import AttnMaskType
from .flash_attention_loader import ColoAttention, FlashAttentionLoader

__all__ = [
    "LayerNorm",
    "FusedScaleMaskSoftmax",
    "MultiHeadAttention",
    "CPUAdamLoader",
    "FlashAttentionLoader",
    "ColoAttention",
    "AttnMaskType",
]
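
For orientation, a minimal sketch of how the re-exported loader API is meant to be used from the package root (this mirrors the usage shown in the CPUAdamLoader docstring further down; it assumes the package is installed and a CPU kernel can be built or loaded on the current machine):

from colossalai.kernel import CPUAdamLoader

# load() runs the loader's checks and returns the compiled kernel module
cpu_adam = CPUAdamLoader().load()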
28 changes: 28 additions & 0 deletions colossalai/kernel/base_kernel_loader.py
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from typing import Dict, List

from .extensions.base_extension import BaseExtension


class BaseKernelLoader(ABC):
    """
    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
        self._extension_map = extension_map
        self._supported_device = supported_device

    def run_checks(self):
        # run supported device check and other possible checks
        pass

    @abstractmethod
    def fetch_kernel(self):
        pass

    def load(self):
        self.run_checks()
        return self.fetch_kernel()
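
To make the intended subclassing pattern concrete, here is a hedged sketch of a loader/extension pair; DummyExtension and DummyKernelLoader are hypothetical names invented for illustration, not part of this PR:

from collections import OrderedDict

from colossalai.kernel.base_kernel_loader import BaseKernelLoader
from colossalai.kernel.extensions.base_extension import BaseExtension


class DummyExtension(BaseExtension):
    # hypothetical extension whose "kernel" is a plain Python callable
    @property
    def requires_build(self) -> bool:
        return False

    def build(self) -> None:
        pass

    def load(self):
        return lambda: "dummy kernel"


class DummyKernelLoader(BaseKernelLoader):
    # hypothetical loader that always dispatches to the single extension
    def __init__(self):
        super().__init__(
            extension_map=OrderedDict(dummy=DummyExtension),
            supported_device=["cpu"],
        )

    def fetch_kernel(self):
        # fetch() builds the extension if needed, then loads it
        return self._extension_map["dummy"]().fetch()


kernel = DummyKernelLoader().load()  # returns the callable defined above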
64 changes: 64 additions & 0 deletions colossalai/kernel/cpu_adam_loader.py
@@ -0,0 +1,64 @@
import platform
from collections import OrderedDict

from .base_kernel_loader import BaseKernelLoader
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension


class CPUAdamLoader(BaseKernelLoader):
    """
    CPU Adam Loader

    Usage:
        # init
        cpu_adam = CPUAdamLoader().load()
        cpu_adam_op = cpu_adam.CPUAdamOptimizer(
            alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
        )
        ...
        # optim step
        cpu_adam_op.step(
            step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
            params, grads, exp_avg, exp_avg_sq, loss_scale,
        )

    Args:
        func CPUAdamOptimizer:
Contributor: This is not an argument of this class. You may remove this and add type hints in CPUAdamOptimizer, or add a .pyi file.

            alpha (float): learning rate. Defaults to 1e-3.
            beta1 (float): coefficient used for computing the running average of the gradient. Defaults to 0.9.
            beta2 (float): coefficient used for computing the running average of its square. Defaults to 0.99.
            epsilon (float): term added to the denominator to improve numerical stability. Defaults to 1e-8.
            weight_decay (float): weight decay (L2 penalty). Defaults to 0.
            adamw_mode (bool): whether to use AdamW. Defaults to True.
        func step:
            step (int): current step.
            lr (float): learning rate.
            beta1 (float): coefficient used for computing the running average of the gradient.
            beta2 (float): coefficient used for computing the running average of its square.
            epsilon (float): term added to the denominator to improve numerical stability.
            weight_decay (float): weight decay (L2 penalty).
            bias_correction (bool): whether to use bias correction.
            params (torch.Tensor): parameter.
            grads (torch.Tensor): gradient.
            exp_avg (torch.Tensor): exp average.
            exp_avg_sq (torch.Tensor): exp average square.
            loss_scale (float): loss scale value.
    """

    def __init__(self):
        super().__init__(
            extension_map=OrderedDict(
                arm=ArmCPUAdamExtension,
                x86=X86CPUAdamExtension,
            ),
            supported_device=["cpu"],
        )

    def fetch_kernel(self):
        if platform.machine() == "x86_64":
            kernel = self._extension_map["x86"]().fetch()
        elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
Contributor: This kernel only supports aarch64 now.

            kernel = self._extension_map["arm"]().fetch()
        else:
            raise Exception("not supported")
        return kernel
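
Putting the docstring together end to end, a hedged sketch of a full optimizer step (tensor shapes and hyperparameter values are illustrative assumptions, the loss_scale of -1 is assumed to mean "no loss scaling", and running this requires the compiled extension for the current platform):

import torch

from colossalai.kernel import CPUAdamLoader

cpu_adam = CPUAdamLoader().load()  # dispatches to the x86 or arm extension
op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)

# one step; the arguments follow the `step` signature documented above
op.step(1, 1e-3, 0.9, 0.99, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1)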
2 changes: 0 additions & 2 deletions colossalai/kernel/cuda_native/__init__.py
@@ -1,5 +1,4 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

@@ -8,6 +7,5 @@
    "MultiHeadAttention",
    "FusedScaleMaskSoftmax",
    "ScaledUpperTriangMaskedSoftmax",
    "ColoAttention",
    "AttnMaskType",
]
3 changes: 0 additions & 3 deletions colossalai/kernel/cuda_native/mha/__init__.py

This file was deleted.

114 changes: 0 additions & 114 deletions colossalai/kernel/cuda_native/mha/mha.py

This file was deleted.

Empty file.
21 changes: 21 additions & 0 deletions colossalai/kernel/extensions/base_extension.py
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from typing import Callable


class BaseExtension(ABC):
    @abstractmethod
    def requires_build(self) -> bool:
        pass

    @abstractmethod
    def build(self) -> None:
        pass

    @abstractmethod
    def load(self) -> Callable:
        pass

    def fetch(self) -> Callable:
        if self.requires_build:
            self.build()
        return self.load()
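
For illustration, the fetch() contract can be exercised directly against one of the concrete extensions added below (a sketch, assuming an x86_64 machine with the build toolchain and CUDA headers available):

from colossalai.kernel.extensions.cpu_adam import X86CPUAdamExtension

ext = X86CPUAdamExtension()
# _requires_build starts as False, so fetch() skips build() and goes straight to load()
kernel = ext.fetch()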
4 changes: 4 additions & 0 deletions colossalai/kernel/extensions/cpu_adam/__init__.py
@@ -0,0 +1,4 @@
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension

__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]
53 changes: 53 additions & 0 deletions colossalai/kernel/extensions/cpu_adam/arm_extension.py
@@ -0,0 +1,53 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder


class ArmCPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = ArmCPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class ArmCPUAdamBuilder(ExtensionBuilder):
    NAME = "arm_cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
    ext_type = "cpu"

    def __init__(self):
        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # the four required builder functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        return []
65 changes: 65 additions & 0 deletions colossalai/kernel/extensions/cpu_adam/x86_extension.py
@@ -0,0 +1,65 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads


class X86CPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = X86CPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class X86CPUAdamBuilder(ExtensionBuilder):
    NAME = "cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"

    def __init__(self):
        super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # the four required builder functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-lcudart",
            "-lcublas",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
            "-march=native",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-std=c++14",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]
        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)