diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
index 8933fc0a3c2f..a99cb497c3e7 100644
--- a/colossalai/kernel/__init__.py
+++ b/colossalai/kernel/__init__.py
@@ -1,7 +1,14 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
+from .triton import softmax
+from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
+    "llama_context_attn_fwd",
+    "bloom_context_attn_fwd",
+    "softmax",
+    "copy_kv_cache_to_dest",
 ]
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
new file mode 100644
index 000000000000..9655d720406a
--- /dev/null
+++ b/colossalai/kernel/triton/__init__.py
@@ -0,0 +1,3 @@
+from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
+from .softmax import softmax
+from .copy_kv_cache_dest import copy_kv_cache_to_dest
diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py
index 6c10ee3ffe3f..63d77ce3e16e 100644
--- a/tests/test_infer_ops/triton/test_bloom_context_attention.py
+++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -9,8 +9,8 @@
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import bloom_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index 04d08140815d..e7446b289acd 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -9,8 +9,8 @@
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import llama_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
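
Note: the net effect of this diff is that the Triton kernels become importable from the package level rather than from their defining modules. A minimal sketch of the new import surface follows, assuming Python; the guarded try/except mirrors the HAS_TRITON pattern from the updated test files, and grouping all four names into one import statement is an illustrative choice, not code from the PR.

    # Sketch: consume the re-exported kernels via the package-level path
    # introduced by this diff (colossalai.kernel.triton). If Triton is not
    # installed, the import fails and HAS_TRITON records that fact.
    try:
        from colossalai.kernel.triton import (
            bloom_context_attn_fwd,
            copy_kv_cache_to_dest,
            llama_context_attn_fwd,
            softmax,
        )

        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False  # caller decides how to fall back without Triton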