7 changes: 7 additions & 0 deletions colossalai/kernel/__init__.py
@@ -1,7 +1,14 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
+from .triton import softmax
+from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
+    "llama_context_attn_fwd",
+    "bloom_context_attn_fwd",
+    "softmax",
+    "copy_kv_cache_to_dest",
 ]
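With these re-exports, callers can pull the Triton kernels straight from colossalai.kernel instead of reaching into the submodule. A minimal import sketch, showing only the paths enabled by this change (call signatures are not part of this diff, so none are shown):

# Package-level imports now exposed by colossalai/kernel/__init__.py
from colossalai.kernel import llama_context_attn_fwd, bloom_context_attn_fwd
from colossalai.kernel import softmax, copy_kv_cache_to_dest

# The deeper submodule path continues to work as well
from colossalai.kernel.triton import bloom_context_attn_fwd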
3 changes: 3 additions & 0 deletions colossalai/kernel/triton/__init__.py
@@ -0,0 +1,3 @@
+from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
+from .softmax import softmax
+from .copy_kv_cache_dest import copy_kv_cache_to_dest
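Because colossalai/kernel/__init__.py now imports these names unconditionally, importing the package in an environment without triton raises ImportError; the updated tests below handle this with a guarded import. A minimal sketch of the same guard for user code (the fallback flag mirrors the tests; any non-Triton fallback path is an assumption, not something this PR provides):

try:
    # Requires a working triton installation
    from colossalai.kernel.triton import llama_context_attn_fwd, copy_kv_cache_to_dest
    HAS_TRITON = True
except ImportError:
    # Assumed fallback: dispatch to non-Triton implementations when triton is absent
    HAS_TRITON = False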
4 changes: 2 additions & 2 deletions tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -9,8 +9,8 @@
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import bloom_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
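The hunk above only covers the import block, so how HAS_TRITON is consumed is not visible here. A common pytest pattern for such a flag (assumed for illustration, not taken from this diff):

import pytest

@pytest.mark.skipif(not HAS_TRITON, reason="triton is not available")
def test_bloom_context_attention():
    # benchmark the Triton kernel against the torch_context_attention reference
    ...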
4 changes: 2 additions & 2 deletions tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -9,8 +9,8 @@
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import llama_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False