Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
pip install -v .
DISABLE_URING=1 pip install -v .

- name: Store TensorNVMe Cache
run: |
Expand Down Expand Up @@ -201,4 +201,4 @@ jobs:
uses: actions/upload-artifact@v3
with:
name: report
path: report/
path: report/
2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
pip install -v .
DISABLE_URING=1 pip install -v .

- uses: actions/checkout@v2
if: steps.check-avai.outputs.avai == 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_dispatch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
cd TensorNVMe
apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
Expand Down
24 changes: 20 additions & 4 deletions colossalai/kernel/kernel_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FlashAttentionSdpaCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
Expand Down Expand Up @@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
else:
usable_exts = []
for ext in exts:
if ext.is_hardware_available():
if ext.is_available():
# make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible()
ext.assert_compatible()
usable_exts.append(ext)

assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
Expand Down Expand Up @@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):


class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
REGISTRY = [
FlashAttentionNpuExtension,
FlashAttentionDaoCudaExtension,
FlashAttentionSdpaCudaExtension,
]


class FlashAttentionWithPaddingMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]


class FlashAttentionWithCustomMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]


class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionSdpaCudaExtension]
209 changes: 0 additions & 209 deletions colossalai/nn/layer/colo_attention.py

This file was deleted.

3 changes: 3 additions & 0 deletions colossalai/shardformer/layer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
Expand All @@ -23,4 +24,6 @@
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
"AttnMaskType",
"ColoAttention",
]
Loading