[npu] use extension for op builder #5172
Merged
Changes from all commits (22 commits):
9c667fb  update extension (oahzxl)
f903fe6  update cpu adam (oahzxl)
6478272  update is (oahzxl)
c8e69e7  Merge branch 'opbuilder' of https://github.com/oahzxl/ColossalAI into… (oahzxl)
cc6c21c  add doc for cpu adam (oahzxl)
7f8979a  update kernel (oahzxl)
d0bc49f  update commit (oahzxl)
0b01dd2  update flash (oahzxl)
eebeab3  update memory efficient (oahzxl)
8606483  update flash attn (oahzxl)
d2c6e23  update flash attention loader (oahzxl)
0cb447d  update api (oahzxl)
1a7c9ce  fix (oahzxl)
982474b  update doc (oahzxl)
1502573  update example time limit (oahzxl)
e53408a  reverse change (oahzxl)
2f77365  fix doc (oahzxl)
13b98b9  remove useless kernel (oahzxl)
b660c5e  fix (oahzxl)
dfff0a0  not use warning (oahzxl)
bab51b0  update (oahzxl)
1b8b313  update (oahzxl)
@@ -1,7 +1,14 @@

```python
from .cpu_adam_loader import CPUAdamLoader
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
from .extensions.flash_attention import AttnMaskType
from .flash_attention_loader import ColoAttention, FlashAttentionLoader

__all__ = [
    "LayerNorm",
    "FusedScaleMaskSoftmax",
    "MultiHeadAttention",
    "CPUAdamLoader",
    "FlashAttentionLoader",
    "ColoAttention",
    "AttnMaskType",
]
```
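This `__init__` hunk is the public surface of the loader work: `CPUAdamLoader`, `FlashAttentionLoader`, `ColoAttention`, and `AttnMaskType` all become importable from the package root. Assuming this is `colossalai.kernel`'s `__init__.py` (the file path is not shown in this view), downstream code would import them as:

```python
from colossalai.kernel import AttnMaskType, ColoAttention, CPUAdamLoader, FlashAttentionLoader
```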
@@ -0,0 +1,28 @@

```python
from abc import ABC, abstractmethod
from typing import Dict, List

from .extensions.base_extension import BaseExtension


class BaseKernelLoader(ABC):
    """
    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
        self._extension_map = extension_map
        self._supported_device = supported_device

    def run_checks(self):
        # run supported device check and other possible checks
        pass

    @abstractmethod
    def fetch_kernel(self):
        pass

    def load(self):
        self.run_checks()
        return self.fetch_kernel()
```
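The contract is deliberately small: a subclass supplies an `extension_map` and a `fetch_kernel()` implementation, and callers only ever invoke `load()`, which runs the (currently no-op) checks first. A minimal self-contained sketch of that flow; `EchoExtension` and `EchoLoader` are illustrative names, not part of this PR:

```python
from collections import OrderedDict
from typing import Callable

# Assumes BaseKernelLoader from the hunk above is importable.


class EchoExtension:
    """Illustrative extension: nothing to build, fetch() just returns a callable.
    It only duck-types the fetch() interface for brevity."""

    def fetch(self) -> Callable:
        return lambda x: x


class EchoLoader(BaseKernelLoader):
    def __init__(self):
        super().__init__(extension_map=OrderedDict(echo=EchoExtension), supported_device=["cpu"])

    def fetch_kernel(self):
        # Pick an entry from the map, instantiate it, and fetch the kernel,
        # mirroring what CPUAdamLoader does below.
        return self._extension_map["echo"]().fetch()


kernel = EchoLoader().load()  # load() = run_checks() + fetch_kernel()
assert kernel("hi") == "hi"
```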
@@ -0,0 +1,64 @@

```python
import platform
from collections import OrderedDict

from .base_kernel_loader import BaseKernelLoader
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension


class CPUAdamLoader(BaseKernelLoader):
    """
    CPU Adam Loader

    Usage:
        # init
        cpu_adam = CPUAdamLoader().load()
        cpu_adam_op = cpu_adam.CPUAdamOptimizer(
            alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
        )
        ...
        # optim step
        cpu_adam_op.step(
            step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
            params, grads, exp_avg, exp_avg_sq, loss_scale,
        )

    Args:
        func CPUAdamOptimizer:
            alpha (float): learning rate. Defaults to 1e-3.
            beta1 (float): coefficient used for computing the running average of the gradient. Defaults to 0.9.
            beta2 (float): coefficient used for computing the running average of the squared gradient. Defaults to 0.99.
            epsilon (float): term added to the denominator to improve numerical stability. Defaults to 1e-8.
            weight_decay (float): weight decay (L2 penalty). Defaults to 0.
            adamw_mode (bool): whether to use AdamW mode. Defaults to True.
        func step:
            step (int): current step.
            lr (float): learning rate.
            beta1 (float): coefficient used for computing the running average of the gradient.
            beta2 (float): coefficient used for computing the running average of the squared gradient.
            epsilon (float): term added to the denominator to improve numerical stability.
            weight_decay (float): weight decay (L2 penalty).
            bias_correction (bool): whether to use bias correction.
            params (torch.Tensor): parameters.
            grads (torch.Tensor): gradients.
            exp_avg (torch.Tensor): running average of the gradient.
            exp_avg_sq (torch.Tensor): running average of the squared gradient.
            loss_scale (float): loss scale value.
    """

    def __init__(self):
        super().__init__(
            extension_map=OrderedDict(
                arm=ArmCPUAdamExtension,
                x86=X86CPUAdamExtension,
            ),
            supported_device=["cpu"],
        )

    def fetch_kernel(self):
        if platform.machine() == "x86_64":
            kernel = self._extension_map["x86"]().fetch()
        elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
            kernel = self._extension_map["arm"]().fetch()
        else:
            raise Exception("not supported")
        return kernel
```

Contributor comment (on the `elif platform.machine() in [...]` branch): "This kernel only supports aarch64 now."
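Filling in the docstring's usage snippet with concrete tensors gives something like the sketch below. It assumes the x86 or ARM binding compiles on the host, that the `CPUAdamOptimizer` and `step` signatures match the docstring (they are not verified here against the C++ source), and that `loss_scale=-1` disables loss scaling:

```python
import torch

from colossalai.kernel import CPUAdamLoader

cpu_adam = CPUAdamLoader().load()
# alpha, beta1, beta2, epsilon, weight_decay, adamw_mode
cpu_adam_op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)

param = torch.rand(64)
grad = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)

# step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
# params, grads, exp_avg, exp_avg_sq, loss_scale
cpu_adam_op.step(1, 1e-3, 0.9, 0.99, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1)
```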
This file was deleted.
This file was deleted.
Empty file.
@@ -0,0 +1,21 @@

```python
from abc import ABC, abstractmethod
from typing import Callable


class BaseExtension(ABC):
    @abstractmethod
    def requires_build(self) -> bool:
        pass

    @abstractmethod
    def build(self) -> None:
        pass

    @abstractmethod
    def load(self) -> Callable:
        pass

    def fetch(self) -> Callable:
        if self.requires_build:
            self.build()
        return self.load()
```
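`fetch()` is the only concrete method, a small template: build on demand, then load. One subtlety is that it tests `self.requires_build` directly, so concrete extensions must expose it as a property (as the Arm/X86 ones below do); a plain method would be a bound object and therefore always truthy. A pure-Python illustration, not part of the PR:

```python
from typing import Callable


class PurePythonExtension(BaseExtension):  # illustrative subclass only
    @property
    def requires_build(self) -> bool:
        return False  # nothing to compile

    def build(self) -> None:
        pass

    def load(self) -> Callable:
        return lambda a, b: a + b


add = PurePythonExtension().fetch()  # requires_build is False, so build() is skipped
assert add(1, 2) == 3
```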
@@ -0,0 +1,4 @@

```python
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension

__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]
```
@@ -0,0 +1,53 @@

```python
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder


class ArmCPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = ArmCPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class ArmCPUAdamBuilder(ExtensionBuilder):
    NAME = "arm_cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
    ext_type = "cpu"

    def __init__(self):
        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        return []
```
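Note the flow implied by initializing `_requires_build = False`: `BaseExtension.fetch()` skips `build()` and goes straight to `load()`, so any actual compilation presumably happens inside `ExtensionBuilder.load()` when no prebuilt `colossalai._C.arm_cpu_adam` module is found (the `ExtensionBuilder` base is not shown in this diff, so this is an assumption). Concretely:

```python
ext = ArmCPUAdamExtension()
assert ext.requires_build is False  # so fetch() will not call build() first

kernel = ext.fetch()  # effectively ArmCPUAdamBuilder().load()
```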
@@ -0,0 +1,65 @@

```python
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads


class X86CPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = X86CPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class X86CPUAdamBuilder(ExtensionBuilder):
    NAME = "cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"

    def __init__(self):
        super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++14",
            "-std=c++17",
            "-lcudart",
            "-lcublas",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
            "-march=native",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-std=c++14",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]
        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
```
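Compared with the ARM builder, the x86 one additionally pulls in the CUDA include directory, passes cudart/cublas link flags, and defines NVCC flags, so it evidently expects CUDA headers and libraries to be available at build time. The `ExtensionBuilder` base is not shown in this diff, but its four "necessary" methods map naturally onto PyTorch's JIT C++ extension loader; the sketch below shows how such a builder could be consumed, as an assumption about the design rather than the real implementation:

```python
import torch.utils.cpp_extension


def jit_build_and_load(builder):
    """Illustrative only: compile and import a builder's sources with torch's
    JIT loader. The real ExtensionBuilder may cache or prefer prebuilt modules."""
    return torch.utils.cpp_extension.load(
        name=builder.NAME,
        sources=builder.sources_files(),
        extra_include_paths=builder.include_dirs(),
        extra_cflags=builder.cxx_flags(),
        extra_cuda_cflags=builder.nvcc_flags(),
        verbose=True,
    )


# e.g. kernel = jit_build_and_load(X86CPUAdamBuilder())
```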
Contributor comment (on the `Args` section of the `CPUAdamLoader` docstring): "These are not the args of this class. You may remove this and add type hints in CPUAdamOptimizer, or add a .pyi file."