From 064a8f2427bc92a53a4fd4ba9824b0f867e8abc9 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 16:34:28 +0800 Subject: [PATCH 01/31] adding kernels --- .../attention_infer_kernels/bert_padding.py | 132 ++++ .../flash_attention.py | 52 ++ .../flash_attn/flash_api.cpp | 420 +++++++++++++ .../flash_attn/src/block_info.h | 41 ++ .../flash_attn/src/flash.h | 144 +++++ .../src/flash_fwd_hdim128_bf16_sm80.cu | 19 + .../src/flash_fwd_hdim128_fp16_sm80.cu | 32 + .../src/flash_fwd_hdim160_bf16_sm80.cu | 17 + .../src/flash_fwd_hdim160_fp16_sm80.cu | 27 + .../src/flash_fwd_hdim192_bf16_sm80.cu | 16 + .../src/flash_fwd_hdim192_fp16_sm80.cu | 27 + .../src/flash_fwd_hdim224_bf16_sm80.cu | 9 + .../src/flash_fwd_hdim224_fp16_sm80.cu | 9 + .../src/flash_fwd_hdim256_bf16_sm80.cu | 9 + .../src/flash_fwd_hdim256_fp16_sm80.cu | 9 + .../src/flash_fwd_hdim32_bf16_sm80.cu | 10 + .../src/flash_fwd_hdim32_fp16_sm80.cu | 23 + .../src/flash_fwd_hdim64_bf16_sm80.cu | 19 + .../src/flash_fwd_hdim64_fp16_sm80.cu | 26 + .../src/flash_fwd_hdim96_bf16_sm80.cu | 17 + .../src/flash_fwd_hdim96_fp16_sm80.cu | 23 + .../flash_attn/src/flash_fwd_kernel.h | 584 ++++++++++++++++++ .../src/flash_fwd_launch_template.h | 251 ++++++++ .../flash_attn/src/kernel_traits.h | 366 +++++++++++ .../flash_attn/src/kernel_traits_sm90.h | 159 +++++ .../flash_attn/src/philox.cuh | 165 +++++ .../flash_attn/src/softmax.h | 272 ++++++++ .../flash_attn/src/static_switch.h | 66 ++ .../flash_attn/src/utils.h | 388 ++++++++++++ .../flash_attn_interface.py | 386 ++++++++++++ .../csrc/attention_infer_kernels/linear.py | 48 ++ .../attention_infer_kernels/linear/gemm.cu | 40 ++ .../linear/linear_op.cpp | 206 ++++++ .../rotary_embedding/pos_encoding.cpp | 15 + .../rotary_embedding/pos_encoding_kernels.cu | 86 +++ .../softmax/fused_softmax.cpp | 46 ++ .../softmax/scaled_masked_softmax.h | 338 ++++++++++ .../softmax/scaled_masked_softmax_cuda.cu | 78 +++ .../softmax/type_shim.h | 20 + colossalai/shardformer/setup.py | 233 +++++++ 40 files changed, 4828 insertions(+) create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h create mode 100644 colossalai/shardformer/setup.py diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py new file mode 100644 index 000000000000..6826949a2604 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py @@ -0,0 +1,132 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F + +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, + repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, 'b ... -> b (...)') + grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, dtype=grad_output.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, + dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + indices, = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, + cu_seqlens, max_seqlen_in_batch) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, '(b s) ... -> b s ...', b=batch) \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py new file mode 100644 index 000000000000..d70703f65f6a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py @@ -0,0 +1,52 @@ +try: + from col_flash_attn_2_lib import flash_fwd, varlen_flash_fwd + HAS_FLASH_CUDA = True +except: + HAS_FLASH_CUDA = False + print("in order to use flash-attention, make sure you install cuda kernels in op directory") + + +if HAS_FLASH_CUDA: + def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_fwd( + q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = varlen_flash_fwd( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, None + ) + + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + + def flash_attention_fwd(qkv, scale, causal = True, return_softmax = False): + assert qkv.is_contiguous() + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + out_flash, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(q, k, v, 0, + softmax_scale = scale, + causal = causal, + return_softmax = return_softmax + ) + + if return_softmax: + return out_flash, softmax_lse + else: + return out_flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp new file mode 100644 index 000000000000..07252a3c85bf --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp @@ -0,0 +1,420 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.o_batch_stride = out.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + + params.is_causal = is_causal; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_(params, stream); + }); + }); +} + +std::vector +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float p_dropout, + const float softmax_scale, + const bool is_causal, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} + +std::vector +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + const bool return_softmax, + c10::optional gen_) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); + TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); + TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) {p.zero_();} + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("flash_fwd", &mha_fwd, "Forward pass"); + m.def("varlen_flash_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h new file mode 100644 index 000000000000..94251a41e43b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h @@ -0,0 +1,41 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const uint32_t actual_seqlen_q; + const uint32_t actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h new file mode 100644 index 000000000000..e65d7d536aa9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include + + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + int *__restrict__ blockmask; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Random state. + at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); + +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000000..654400a74919 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000000..5b7254a918d7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k +// run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// // run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// // 1st ones are good for H100, A100 +// // 2nd one is good for A6000 bc we get slightly better occupancy +// } else { +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// // 1st one is good for H100, A100, A6000 +// } +// } + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000000..6a9d60c39156 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000000..6c40a164d6d8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. +// // For A100, H100, 1st is fastest. +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000000..d2f4cba71528 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000000..2875c92660a9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // This one is slightly faster for causal? +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// }); +// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout +// // For A6000, 1st is faster when causal, 3rd is faster when not causal +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000000..982fe7eadecc --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000000..4c083f7b663f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000000..cb074a95ed8c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000000..ddf5e132293d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000000..81e359e16feb --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000000..91e6331e90cc --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// // For dropout there might be a lot of register spilling? +// // These two are very slow due to register spilling +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// // This one is slightly slower +// // run_flash_fwd>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000000..fffcbebb5d98 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000000..01bd1716720b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower +// // Using block size (64 x 256) is 27% slower for seqlen=2k +// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// run_flash_fwd, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000000..b0b27db59600 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000000..820b63cbbfd5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // This 3rd one is good for H100, and A100, A6000 +// run_flash_fwd, Is_dropout>(params, stream); +// run_flash_fwd, Is_dropout>(params, stream); +// // These two are always slower +// // run_flash_fwd>(params, stream); +// // run_flash_fwd>(params, stream); +// }); +// } +template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h new file mode 100644 index 000000000000..6e7364776b22 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h @@ -0,0 +1,584 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "philox.cuh" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + // TODO: Shouldn't this be size<1>? + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P +) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + Layout l = tOrP.layout(); + Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + #pragma unroll + for (int mi = 0; mi < size<1>(tPrP); ++mi) { + copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + // The global block index. + const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); + auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // Copy rmem to smem + // // copy(tQrQ, tQsQ); + // flash::cp_async_wait<0>(); + // __syncthreads(); + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); + } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + // Save seed and offset for backward. + if (block_id == 0 && tidx == 0) { + params.rng_state[0] = seed; + params.rng_state[1] = std::get<1>(seeds); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout())); + // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. + // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + // if (cute::thread0()) { print(tOrP); } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + uint32_t block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + copy(smem_thr_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + copy(gmem_thr_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h new file mode 100644 index 000000000000..f48186aeb333 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h @@ -0,0 +1,251 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + flash::compute_attn(params); +} + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_q as well. + const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 32; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 64; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 96; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 128; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 160; + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 224; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); + }); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h new file mode 100644 index 000000000000..3468e4bffc37 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h @@ -0,0 +1,366 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposed = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); + static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + + kSmemdSSize + kSmemPSize; + + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h new file mode 100644 index 000000000000..e07f383904a8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh new file mode 100644 index 000000000000..6ce1440f288d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh @@ -0,0 +1,165 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. + +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +inline __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash + +namespace { + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) + : STATE(0) + , seed_(seed) + , offset_(offset) + , key(reinterpret_cast(seed)) { + //key.x = (unsigned int)seed; + //key.y = (unsigned int)(seed >> 32); + //counter = make_uint4(0, 0, 0, 0); + //counter.z = (unsigned int)(subsequence); + //counter.w = (unsigned int)(subsequence >> 32); + //STATE = 0; + //incr_n(offset / 4); + + // key = reinterpret_cast(seed); + ull2 * tmp = reinterpret_cast(&counter); + tmp->x = offset / 4; + tmp->y = subsequence; + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); + // } + } + __device__ inline uint4 operator()() { + // // if (STATE == 0) { + // uint4 counter_ = counter; + // uint2 key_ = key; + // // 7-round philox + // #pragma unroll + // for (int i = 0; i < 6; i++) { + // counter_ = flash::philox_single_round(counter_, key_); + // key_.x += (kPhilox10A); + // key_.y += (kPhilox10B); + // } + // // output = philox_single_round(counter_, key_); + // uint4 output = flash::philox_single_round(counter_, key_); + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); + // // } + // incr(); + // // } + // // return a float4 directly + // // unsigned long ret; + // // switch(STATE) { + // // case 0: ret = output.x; break; + // // case 1: ret = output.y; break; + // // case 2: ret = output.z; break; + // // case 3: ret = output.w; break; + // //} + // // STATE = (STATE + 1) % 4; + // return output; + return flash::philox(seed_, offset_, offset_); + } + +private: + unsigned long long offset_, seed_; + struct ull2 { + uint64_t x; + uint64_t y; + }; + uint4 counter; + // uint4 output; + const uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + + __device__ uint4 incr128 (uint4 ctr) + { + uint4 res; + asm ("add.cc.u32 %0, %4, %8;\n\t" + "addc.cc.u32 %1, %5, %9;\n\t" + "addc.cc.u32 %2, %6, %10;\n\t" + "addc.u32 %3, %7, %11;\n\t" + : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) + : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), + "n"(1), "n"(0), "n"(0), "n"(0)); + return res; + } + + __device__ inline void incr() { + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + counter = incr128(counter); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); + // } + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + // static const unsigned long kPhiloxSA = 0xD2511F53; + // static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h new file mode 100644 index 000000000000..3e9a7b4597c6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h @@ -0,0 +1,272 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include +#include + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, + const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, + const uint32_t warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const uint32_t lane_id = threadIdx.x % 32; + // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; + const uint32_t row_idx_offset = row_idx_offset_; + const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const uint32_t row_idx = row_idx_base + i * 8; + const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const uint32_t col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const uint32_t col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, + unsigned long long seed, unsigned long long offset, + uint32_t block_row_start, uint32_t block_col_start, + uint32_t block_row_stride) { + // tensor has shape (8, MMA_M, MMA_N / 2) + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. + if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } +} + +} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h new file mode 100644 index 000000000000..4aa847402886 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h @@ -0,0 +1,66 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h new file mode 100644 index 000000000000..2221a2faf3a8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h @@ -0,0 +1,388 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = flash::half2_unpack(a); + float2 bf = flash::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = flash::hfma2_to_float(a.x, b.x); + sum += flash::hfma2_to_float(a.y, b.y); + sum += flash::hfma2_to_float(a.z, b.z); + sum += flash::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + copy(thr_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(thr_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py new file mode 100644 index 000000000000..b6f2c012bb54 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py @@ -0,0 +1,386 @@ +import torch +import torch.nn as nn + +import col_flash_attn_2_lib as flash_attn_cuda +from einops import rearrange + + +def _get_block_size(device, head_dim, is_dropout, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128, 128 + if head_dim <= 64: + return (128, 128) if not is_dropout else (128, 64) + elif head_dim <= 96: + return (64, 64) if (is_sm8x and is_causal) else (128, 64) + elif head_dim <= 128: + if is_sm8x: + return (64, 64) if (not is_dropout and is_causal) else (128, 32) + else: + return 128, (64 if not is_dropout else 32) + elif head_dim <= 160: + if is_sm8x: + return (128, 64) if not is_causal else (64, 64) + else: + return 128, 32 + elif head_dim <= 192: + return (128, 64) if not is_dropout else (64, 64) + elif head_dim <= 224: + return (128, 64) if (is_sm80 or is_sm90) else (64, 64) + elif head_dim <= 256: + return (128, 64) if is_sm80 else (64, 64) + + +def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.flash_fwd( + q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_flash_fwd( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, None + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, + causal=causal, return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + + +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + q, k, v, dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + +class FlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + + +def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_kvpacked_func and flash_attn_func. + + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) + + +def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + kv: (batch_size, seqlen, 2, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) + + +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) + + +def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, + causal=False, return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + ) + + +def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenKVPackedFunc.apply( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs + ) + + +def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenFunc.apply( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs + ) \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py new file mode 100644 index 000000000000..718cbcbad4d8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py @@ -0,0 +1,48 @@ +import torch +try: + from col_linear_lib import dense_layer_fp32_forward, dense_layer_fp16_forward, batch_dense_layer_fp16_forward + HAS_FLASH_CUDA = True +except: + HAS_FLASH_CUDA = False + print("in order to use flash-attention, make sure you install cuda kernels in op directory") + + +if HAS_FLASH_CUDA: + def linear(data, weight): + data_shape = None + if len(data.shape) > 2: + data_shape = data.shape + data = data.view(-1, data.shape[-1]) + + assert data.dtype == torch.float16, "only fp16 precision supports" + assert len(data.shape) == 2, "the shape must be 2-D" + assert len(weight.shape) == 2, "the shape must be 2-D" + + M, K = data.shape + _, N = weight.shape + + assert K == weight.shape[0], "the shape is not matchted" + + out = torch.empty((M, N), device=data.get_device(), dtype=torch.float16) + dense_layer_fp16_forward(data, weight, out, 99) + if data_shape is not None: + out = out.view(*data_shape[:-1], N) + return out + + + def batch_linear(data, weight, alibi = None, alpha = 1, beta = 0): + """ + it is equivalent to alibi.bmm(data, weight) + only supports float16 + """ + batch_count, M, K = data.shape + _, N = weight.shape + assert data.shape[-1] == weight.shape[0], "the k-dimensions must be matched" + if alibi is None: + out = torch.empty((batch_count, M, N), dtype=torch.float16, device=data.get_device()) + else: + out = alibi + + batch_dense_layer_fp16_forward(data, weight, out, alpha, beta, 99) + return out + \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu new file mode 100644 index 000000000000..a0c42bdb05fb --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu @@ -0,0 +1,40 @@ +#include +#include +#include + +#include +#include + +void dense_layer_fp32_kernel(const float *in, const float *weight, float *out, const int M, + const int K, const int N, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo) { + const float alpha = 1.0f, beta = 0.0f; + cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, weight, + CUDA_R_32F, N, in, CUDA_R_32F, K, &beta, out, CUDA_R_32F, N, + CUDA_R_32F, static_cast(cublasAlgo)); +} + +void dense_layer_fp16_kernel(const __half *in, const __half *weight, __half *out, const int M, + const int K, const int N, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo) { + const __half alpha = (__half)1.0f, beta = (__half)0.0f; + cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, weight, + CUDA_R_16F, N, in, CUDA_R_16F, K, &beta, out, CUDA_R_16F, N, + CUDA_R_16F, static_cast(cublasAlgo)); +} + + +void cublas_Gemm_Strided_Batched_FP16_Kernel(const __half *A, const __half *B, __half *out, const int M, + const int K, const int N, const int batch_count, + cublasOperation_t trans_A, cublasOperation_t trans_B, + __half alpha, __half beta, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo) { + const int lda = (trans_A == CUBLAS_OP_N) ? K : M; + const int ldb = (trans_B == CUBLAS_OP_N) ? N : K; + + + cublasGemmStridedBatchedEx( + cublas_handle, trans_B, trans_A, N, M, K, &alpha, B, CUDA_R_16F, ldb, K * N, A, CUDA_R_16F, + lda, M * K, &beta, out, CUDA_R_16F, N, M * N, batch_count, CUDA_R_16F, + static_cast(cublasAlgo)); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp new file mode 100644 index 000000000000..dc4f50ebfa55 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp @@ -0,0 +1,206 @@ +#include +#include +#include +#include +#include + + +class CublasHandle +{ +public: + static CublasHandle& instance() + { + static CublasHandle handle; + return handle; + } + + cublasHandle_t get() const + { + return handle; + } + + CublasHandle(CublasHandle const&) = delete; + void operator=(CublasHandle const&) = delete; + +private: + cublasHandle_t handle; + + CublasHandle() + { + cublasStatus_t stat = cublasCreate(&handle); + if (stat != CUBLAS_STATUS_SUCCESS) + { + printf("cuBLAS initialization error: %d\n", stat); + exit(stat); + } + } + + ~CublasHandle() + { + cublasDestroy(handle); + } +}; + +class CudaStream { +public: + // Get the singleton instance + static CudaStream& instance() { + static CudaStream instance; + return instance; + } + + // Get the cudaStream_t + cudaStream_t get() const { + return stream; + } + +private: + // The cudaStream_t object + cudaStream_t stream; + + // Private constructor and destructor + CudaStream() { + cudaError_t err = cudaStreamCreate(&stream); + if (err != cudaSuccess) { + printf("cuda stream initialization error"); + exit(-1); + } + } + + ~CudaStream() { + cudaStreamDestroy(stream); + } + + // Delete copy and assignment constructors + CudaStream(const CudaStream&) = delete; + CudaStream& operator=(const CudaStream&) = delete; +}; + + + + +void dense_layer_fp32_kernel(const float *in, const float *weight, float *out, const int M, + const int K, const int N, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo = -1); + +void dense_layer_fp16_kernel(const __half *in, const __half *weight, __half *out, const int M, + const int K, const int N, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo = 99); + + +void cublas_Gemm_Strided_Batched_FP16_Kernel(const __half *A, const __half *B, __half *out, const int M, + const int K, const int N, const int batch_count, + cublasOperation_t trans_A, cublasOperation_t trans_B, + __half alpha, __half beta, cublasHandle_t cublas_handle, + cudaStream_t stream, int cublasAlgo = 99); + + +void dense_layer_fp32_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, int cublasAlgo) { + const int M = in.size(0); + const int K = in.size(1); + const int N = weight.size(1); + // Assumes in and weight are CUDA tensors, hence can call .data_ptr. + + cublasHandle_t handle = CublasHandle::instance().get(); + + // Now you can get a cudaStream_t like this: + cudaStream_t stream = CudaStream::instance().get(); + + dense_layer_fp32_kernel(in.data_ptr(), weight.data_ptr(), out.data_ptr(), M, K, N, handle, stream, cublasAlgo); + +} + + + +void dense_layer_fp16_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, int cublasAlgo = 99) { + const int M = in.size(0); + const int K = in.size(1); + const int N = weight.size(1); + + cublasHandle_t handle = CublasHandle::instance().get(); + + // Now you can get a cudaStream_t like this: + cudaStream_t stream = CudaStream::instance().get(); + + if(in.is_contiguous() == false){ + in = in.contiguous(); + } + + if(weight.is_contiguous() == false) { + weight = weight.contiguous(); + } + + if(out.is_contiguous() == false) { + out = out.contiguous(); + } + + dense_layer_fp16_kernel(reinterpret_cast(in.data_ptr()), + reinterpret_cast(weight.data_ptr()), + reinterpret_cast<__half*>(out.data_ptr()), + M, K, N, handle, stream, cublasAlgo); + + + +} + +void batch_dense_layer_fp16_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, float alpha, float beta, bool weight_transpose = false, int cublasAlgo = 99) { + const int batch_count = in.size(0); + const int M = in.size(1); + const int K = in.size(2); + int N = weight.size(2); + if(weight_transpose) { + N = weight.size(1); + } + + cublasHandle_t handle = CublasHandle::instance().get(); + + // Now you can get a cudaStream_t like this: + cudaStream_t stream = CudaStream::instance().get(); + + // if(in.is_contiguous() == false){ + // in = in.contiguous(); + // } + + // if(weight.is_contiguous() == false) { + // weight = weight.contiguous(); + // } + + // if(out.is_contiguous() == false) { + // out = out.contiguous(); + // } + if(weight_transpose == false) { + cublas_Gemm_Strided_Batched_FP16_Kernel(reinterpret_cast(in.data_ptr()), + reinterpret_cast(weight.data_ptr()), + reinterpret_cast<__half*>(out.data_ptr()), + M, K, N, batch_count, + CUBLAS_OP_N, CUBLAS_OP_N, + (__half)alpha, (__half)beta, handle, stream, cublasAlgo + ); + }else { + cublas_Gemm_Strided_Batched_FP16_Kernel(reinterpret_cast(in.data_ptr()), + reinterpret_cast(weight.data_ptr()), + reinterpret_cast<__half*>(out.data_ptr()), + M, K, N, batch_count, + CUBLAS_OP_N, CUBLAS_OP_T, + (__half)alpha, (__half)beta, handle, stream, cublasAlgo + ); + } +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dense_layer_fp32_forward", + &dense_layer_fp32_forward, + "fp32 forward of dense layer"); + + m.def("dense_layer_fp16_forward", + &dense_layer_fp16_forward, + "fp16 forward of dense layer." + ); + + m.def("batch_dense_layer_fp16_forward", + &batch_dense_layer_fp16_forward, + "fp16 forward of batch gemm" + ); + +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp new file mode 100644 index 000000000000..565d134cdedf --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -0,0 +1,15 @@ +#include + +void rotary_embedding_neox( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "rotary_embedding_neox", + &rotary_embedding_neox, + "Apply GPT-NeoX style rotary embedding to query and key"); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu new file mode 100644 index 000000000000..1f0f8968619b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -0,0 +1,86 @@ +#include +#include + + +template +__global__ void rotary_embedding_neox_kernel( + const int64_t* __restrict__ positions, // [num_tokens] + scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * stride + head_idx * head_size; + + const int rot_offset = i % embed_dim; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int out_x = token_idx * stride + head_idx * head_size + x_index; + const int out_y = token_idx * stride + head_idx * head_size + y_index; + + const scalar_t cos = __ldg(cache_ptr + x_index); + const scalar_t sin = __ldg(cache_ptr + y_index); + + const scalar_t q_x = query[token_head + x_index]; + const scalar_t q_y = query[token_head + y_index]; + query[out_x] = q_x * cos - q_y * sin; + query[out_y] = q_y * cos + q_x * sin; + + if (head_idx < num_kv_heads) { + const scalar_t k_x = key[token_head + x_index]; + const scalar_t k_y = key[token_head + y_index]; + key[out_x] = k_x * cos - k_y * sin; + key[out_y] = k_y * cos + k_x * sin; + } + } +} + + +void rotary_embedding_neox( + torch::Tensor& positions, // [num_tokens] + torch::Tensor& query, // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache) // [max_position, rot_dim] +{ + int num_tokens = query.size(0); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(1) / head_size; + int num_kv_heads = key.size(1) / head_size; + int stride = query.stride(0); + TORCH_CHECK(stride == key.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + query.scalar_type(), + "rotary_embedding_neox", + [&] { + rotary_embedding_neox_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + stride, + num_heads, + num_kv_heads, + head_size); + }); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp new file mode 100644 index 000000000000..ffc68e6c731c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp @@ -0,0 +1,46 @@ +/* coding=utf-8 + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("scaled_masked_softmax_forward", + &fwd, + "self-multihead attention scaled masked softmax(forward)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h new file mode 100644 index 000000000000..d923ade203dc --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h @@ -0,0 +1,338 @@ +/* coding=utf-8 + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + // compute scale value to account for full mask + acc_t scale_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; + } + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + + +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000000..de3547671c5f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu @@ -0,0 +1,78 @@ +/* coding=utf-8 + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + + + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches + ); + ); + return softmax_results; +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h new file mode 100644 index 000000000000..815ec7ec8896 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h @@ -0,0 +1,20 @@ +#include + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ +switch(TYPE) \ +{ \ +case at::ScalarType::Half: \ + { \ +using scalar_t = at::Half; \ +__VA_ARGS__; \ +break; \ + } \ +case at::ScalarType::BFloat16: \ + { \ +using scalar_t = at::BFloat16; \ +__VA_ARGS__; \ +break; \ + } \ +default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ +} diff --git a/colossalai/shardformer/setup.py b/colossalai/shardformer/setup.py new file mode 100644 index 000000000000..508ea462503f --- /dev/null +++ b/colossalai/shardformer/setup.py @@ -0,0 +1,233 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version + +from setuptools import setup, find_packages +import subprocess + +import torch +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils import cpp_extension + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("flash_attn") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_75,code=sm_75") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +setup( + name='colossal-cuda-kernels', + ext_modules=[ + CUDAExtension( + name='col_fused_softmax_lib', + sources=['softmax/fused_softmax.cpp', 'softmax/scaled_masked_softmax_cuda.cu'], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) + } + ), + + CUDAExtension( + name="col_pos_encoding_ops", + sources=["rotary_embedding/pos_encoding.cpp", "rotary_embedding/pos_encoding_kernels.cu"], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) + }, + ), + CUDAExtension( + name="col_flash_attn_2_lib", + sources=[ + "flash_attn/flash_api.cpp", + "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'flash_attn' , + Path(this_dir) / 'flash_attn' / 'src', + Path(this_dir) / 'cutlass' / 'include', + ], + ), + + CUDAExtension( + name="col_linear_lib", + sources=[ + "linear/linear_op.cpp", + "linear/gemm.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo" + "-lcublas" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'linear' , + Path(this_dir) / 'cutlass' / 'include', + ], + ), + + ], + cmdclass={ + 'build_ext': BuildExtension +}) + + + + + + From 68a47359da83bc5673878ca95ac67c45fdfdc632 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 17:38:25 +0800 Subject: [PATCH 02/31] update cutlass --- .gitmodules | 3 + .../bert_padding.py | 0 colossalai/kernel/cuda_native/csrc/cutlass | 1 + .../flash_attn_interface.py | 3 +- .../attention_infer_kernels => }/linear.py | 0 colossalai/kernel/cuda_native/setup.py | 233 ++++++++++++++++++ 6 files changed, 238 insertions(+), 2 deletions(-) rename colossalai/kernel/cuda_native/{csrc/attention_infer_kernels => }/bert_padding.py (100%) create mode 160000 colossalai/kernel/cuda_native/csrc/cutlass rename colossalai/kernel/cuda_native/{csrc/attention_infer_kernels => }/flash_attn_interface.py (99%) rename colossalai/kernel/cuda_native/{csrc/attention_infer_kernels => }/linear.py (100%) create mode 100644 colossalai/kernel/cuda_native/setup.py diff --git a/.gitmodules b/.gitmodules index 2f1c34298a50..9c4a5ae34744 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,6 @@ [submodule "examples/tutorial/fastfold/FastFold"] path = examples/tutorial/fastfold/FastFold url = https://github.com/hpcaitech/FastFold +[submodule "colossalai/kernel/cuda_native/csrc/cutlass"] + path = colossalai/kernel/cuda_native/csrc/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py b/colossalai/kernel/cuda_native/bert_padding.py similarity index 100% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/bert_padding.py rename to colossalai/kernel/cuda_native/bert_padding.py diff --git a/colossalai/kernel/cuda_native/csrc/cutlass b/colossalai/kernel/cuda_native/csrc/cutlass new file mode 160000 index 000000000000..c4f6b8c6bc94 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cutlass @@ -0,0 +1 @@ +Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933 diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py b/colossalai/kernel/cuda_native/flash_attn_interface.py similarity index 99% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py rename to colossalai/kernel/cuda_native/flash_attn_interface.py index b6f2c012bb54..bf57ea7e9871 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn_interface.py +++ b/colossalai/kernel/cuda_native/flash_attn_interface.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn -import col_flash_attn_2_lib as flash_attn_cuda from einops import rearrange - +import col_flash_attn_2_lib as flash_attn_cuda def _get_block_size(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py b/colossalai/kernel/cuda_native/linear.py similarity index 100% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear.py rename to colossalai/kernel/cuda_native/linear.py diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py new file mode 100644 index 000000000000..e8b11c4d8bf2 --- /dev/null +++ b/colossalai/kernel/cuda_native/setup.py @@ -0,0 +1,233 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version + +from setuptools import setup, find_packages +import subprocess + +import torch +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils import cpp_extension + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("flash_attn") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_75,code=sm_75") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +setup( + name='colossal-cuda-kernels', + ext_modules=[ + CUDAExtension( + name='col_fused_softmax_lib', + sources=['softmax/fused_softmax.cpp', 'softmax/scaled_masked_softmax_cuda.cu'], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) + } + ), + + CUDAExtension( + name="col_pos_encoding_ops", + sources=["rotary_embedding/pos_encoding.cpp", "rotary_embedding/pos_encoding_kernels.cu"], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) + }, + ), + CUDAExtension( + name="col_flash_attn_2_lib", + sources=[ + "flash_attn/flash_api.cpp", + "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'csrc' /'flash_attn' , + Path(this_dir) / 'csrc' /'flash_attn' / 'src', + Path(this_dir) / 'csrc'/'cutlass' / 'include', + ], + ), + + CUDAExtension( + name="col_linear_lib", + sources=[ + "linear/linear_op.cpp", + "linear/gemm.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo" + "-lcublas" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'csrc' /'linear' , + Path(this_dir) / 'csrc'/'cutlass' / 'include', + ], + ), + + ], + cmdclass={ + 'build_ext': BuildExtension +}) + + + + + + From c70dcbeec11ecfde394a89303958d82c1c56d80c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 17:38:43 +0800 Subject: [PATCH 03/31] update --- colossalai/shardformer/setup.py | 233 -------------------------------- 1 file changed, 233 deletions(-) delete mode 100644 colossalai/shardformer/setup.py diff --git a/colossalai/shardformer/setup.py b/colossalai/shardformer/setup.py deleted file mode 100644 index 508ea462503f..000000000000 --- a/colossalai/shardformer/setup.py +++ /dev/null @@ -1,233 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME -from torch.utils import cpp_extension - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("flash_attn") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") -# cc_flag.append("-gencode") -# cc_flag.append("arch=compute_75,code=sm_75") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -setup( - name='colossal-cuda-kernels', - ext_modules=[ - CUDAExtension( - name='col_fused_softmax_lib', - sources=['softmax/fused_softmax.cpp', 'softmax/scaled_masked_softmax_cuda.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ), - - CUDAExtension( - name="col_pos_encoding_ops", - sources=["rotary_embedding/pos_encoding.cpp", "rotary_embedding/pos_encoding_kernels.cu"], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - }, - ), - CUDAExtension( - name="col_flash_attn_2_lib", - sources=[ - "flash_attn/flash_api.cpp", - "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'flash_attn' , - Path(this_dir) / 'flash_attn' / 'src', - Path(this_dir) / 'cutlass' / 'include', - ], - ), - - CUDAExtension( - name="col_linear_lib", - sources=[ - "linear/linear_op.cpp", - "linear/gemm.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo" - "-lcublas" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'linear' , - Path(this_dir) / 'cutlass' / 'include', - ], - ), - - ], - cmdclass={ - 'build_ext': BuildExtension -}) - - - - - - From a8938bae4d518d741e7b4a636cffe822ee50d9d5 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 18:11:11 +0800 Subject: [PATCH 04/31] adding kernels --- colossalai/kernel/cuda_native/setup.py | 57 ++++++++++++++------------ 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py index e8b11c4d8bf2..ba6f4a6fdbd1 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/colossalai/kernel/cuda_native/setup.py @@ -15,9 +15,6 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils import cpp_extension -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -125,7 +122,10 @@ def append_nvcc_threads(nvcc_extra_args): ext_modules=[ CUDAExtension( name='col_fused_softmax_lib', - sources=['softmax/fused_softmax.cpp', 'softmax/scaled_masked_softmax_cuda.cu'], + sources=[ + 'csrc/attention_infer_kernels/softmax/fused_softmax.cpp', + 'csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu' + ], extra_compile_args={ 'cxx': ['-O3',], 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) @@ -134,7 +134,10 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_pos_encoding_ops", - sources=["rotary_embedding/pos_encoding.cpp", "rotary_embedding/pos_encoding_kernels.cu"], + sources=[ + "csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp", + "csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu" + ], extra_compile_args={ 'cxx': ['-O3',], 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) @@ -143,23 +146,23 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_flash_attn_2_lib", sources=[ - "flash_attn/flash_api.cpp", - "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", - "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/flash_api.cpp", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, @@ -182,8 +185,8 @@ def append_nvcc_threads(nvcc_extra_args): ), }, include_dirs=[ - Path(this_dir) / 'csrc' /'flash_attn' , - Path(this_dir) / 'csrc' /'flash_attn' / 'src', + Path(this_dir) / 'csrc'/'attention_infer_kernels'/'flash_attn' , + Path(this_dir) / 'csrc'/ 'attention_infer_kernels'/'flash_attn' / 'src', Path(this_dir) / 'csrc'/'cutlass' / 'include', ], ), @@ -191,8 +194,8 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_linear_lib", sources=[ - "linear/linear_op.cpp", - "linear/gemm.cu", + "csrc/attention_infer_kernels/linear/linear_op.cpp", + "csrc/attention_infer_kernels/linear/gemm.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, @@ -216,7 +219,7 @@ def append_nvcc_threads(nvcc_extra_args): ), }, include_dirs=[ - Path(this_dir) / 'csrc' /'linear' , + Path(this_dir) / 'csrc'/'attention_infer_kernels' /'linear' , Path(this_dir) / 'csrc'/'cutlass' / 'include', ], ), From b9b4396cd429ddeb7e52372a23fc50c0c16ba4ec Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 18:15:53 +0800 Subject: [PATCH 05/31] delete useless files --- .../flash_attention.py | 52 ------------------- 1 file changed, 52 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py deleted file mode 100644 index d70703f65f6a..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attention.py +++ /dev/null @@ -1,52 +0,0 @@ -try: - from col_flash_attn_2_lib import flash_fwd, varlen_flash_fwd - HAS_FLASH_CUDA = True -except: - HAS_FLASH_CUDA = False - print("in order to use flash-attention, make sure you install cuda kernels in op directory") - - -if HAS_FLASH_CUDA: - def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_fwd( - q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None - ) - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_softmax): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = varlen_flash_fwd( - q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, False, causal, return_softmax, None - ) - - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - - def flash_attention_fwd(qkv, scale, causal = True, return_softmax = False): - assert qkv.is_contiguous() - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - out_flash, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(q, k, v, 0, - softmax_scale = scale, - causal = causal, - return_softmax = return_softmax - ) - - if return_softmax: - return out_flash, softmax_lse - else: - return out_flash From 7a236e3cd72eb7f983fba52e1bb9bd3f5aecb2a9 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 18:20:53 +0800 Subject: [PATCH 06/31] clean codes --- .../src/flash_fwd_hdim128_bf16_sm80.cu | 9 ------- .../src/flash_fwd_hdim128_fp16_sm80.cu | 21 --------------- .../src/flash_fwd_hdim160_bf16_sm80.cu | 7 ----- .../src/flash_fwd_hdim160_fp16_sm80.cu | 17 ------------ .../src/flash_fwd_hdim192_bf16_sm80.cu | 7 ----- .../src/flash_fwd_hdim192_fp16_sm80.cu | 17 ------------ .../src/flash_fwd_hdim32_fp16_sm80.cu | 13 ---------- .../src/flash_fwd_hdim64_bf16_sm80.cu | 9 ------- .../src/flash_fwd_hdim64_fp16_sm80.cu | 16 ------------ .../src/flash_fwd_hdim96_bf16_sm80.cu | 7 ----- .../src/flash_fwd_hdim96_fp16_sm80.cu | 14 ---------- .../flash_attn/src/flash_fwd_kernel.h | 12 --------- .../src/flash_fwd_launch_template.h | 26 +------------------ 13 files changed, 1 insertion(+), 174 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu index 654400a74919..2c8f75b17973 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -4,15 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim128(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu index 5b7254a918d7..eca6a06632bd 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -4,27 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k -// run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// // run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// // 1st ones are good for H100, A100 -// // 2nd one is good for A6000 bc we get slightly better occupancy -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// // 1st one is good for H100, A100, A6000 -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu index 6a9d60c39156..898cd9c4b6dd 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -4,13 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu index 6c40a164d6d8..19adb4b28deb 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -4,23 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. -// // For A100, H100, 1st is fastest. -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim160(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu index d2f4cba71528..130bf71d0c5d 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -4,13 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); } diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu index 2875c92660a9..32cff41a55f0 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -4,23 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This one is slightly faster for causal? -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout -// // For A6000, 1st is faster when causal, 3rd is faster when not causal -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim192(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu index 91e6331e90cc..b20a1781560e 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -4,19 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// // For dropout there might be a lot of register spilling? -// // These two are very slow due to register spilling -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// // This one is slightly slower -// // run_flash_fwd>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu index fffcbebb5d98..12b4552c2073 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -4,15 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// if (params.p_dropout == 1.f) { -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu index 01bd1716720b..dd20bc67282b 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -4,22 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// if (params.p_dropout == 1.f) { -// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower -// // Using block size (64 x 256) is 27% slower for seqlen=2k -// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// run_flash_fwd, false>(params, stream); -// } else { -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// run_flash_fwd, true>(params, stream); -// } -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim64(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu index b0b27db59600..7039334a3ae9 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -4,13 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::bfloat16_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu index 820b63cbbfd5..a8420bd02945 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -4,20 +4,6 @@ #include "flash_fwd_launch_template.h" -// template<> -// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { -// using elem_type = cutlass::half_t; -// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // This 3rd one is good for H100, and A100, A6000 -// run_flash_fwd, Is_dropout>(params, stream); -// run_flash_fwd, Is_dropout>(params, stream); -// // These two are always slower -// // run_flash_fwd>(params, stream); -// // run_flash_fwd>(params, stream); -// }); -// } template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim96(params, stream); } \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h index 6e7364776b22..7539d71dfc50 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h @@ -236,18 +236,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) - // if (cute::thread0()) { - // print(tScQ.layout()); printf("\n"); - // for (int i = 0; i < size(tScQ); ++i) { - // printf("%d ", get<0>(tScQ(i))); - // } - // printf("\n"); - // for (int i = 0; i < size(tScQ); ++i) { - // printf("%d ", get<1>(tScQ(i))); - // } - // printf("\n"); - // } // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h index f48186aeb333..1f205961f7b3 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h @@ -72,13 +72,8 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { // Using block size (64 x 256) is 27% slower for seqlen=2k // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } }); }); @@ -101,11 +96,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); + }); }); } @@ -139,9 +130,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { // 2nd one is good for A6000 bc we get slightly better occupancy } else { run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } }); }); @@ -166,13 +154,6 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); }); }); } @@ -187,11 +168,6 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); }); }); } From 54ac1e16b22a3f25be674bf627d54e6c008d8972 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 10 Aug 2023 18:48:51 +0800 Subject: [PATCH 07/31] added cuda test --- tests/test_kernels/cuda/test_softmax.py | 71 +++++++++++++++++++ .../{ => triton}/test_self_attention.py | 0 .../test_kernels/{ => triton}/test_softmax.py | 0 3 files changed, 71 insertions(+) create mode 100644 tests/test_kernels/cuda/test_softmax.py rename tests/test_kernels/{ => triton}/test_self_attention.py (100%) rename tests/test_kernels/{ => triton}/test_softmax.py (100%) diff --git a/tests/test_kernels/cuda/test_softmax.py b/tests/test_kernels/cuda/test_softmax.py new file mode 100644 index 000000000000..7881c9d3811c --- /dev/null +++ b/tests/test_kernels/cuda/test_softmax.py @@ -0,0 +1,71 @@ +import os +import numpy as np + +import torch +from torch.nn import functional as F +from col_fused_softmax_lib import scaled_masked_softmax_forward + +def get_latency_for_cuda(func, data, mask, scale): + starter, ender = torch.cuda.Event( + enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = 300 + + for i in range(10): + func(data, mask, scale) + + timings = np.zeros((repetitions, 1)) + with torch.no_grad(): + for rep in range(repetitions): + starter.record() + func(data, mask, 1) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[rep] = curr_time + + mean_syn = np.sum(timings) / repetitions + return mean_syn + + +def get_latency_for_torch(func, data): + starter, ender = torch.cuda.Event( + enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = 300 + + for i in range(10): + func(data, dim=-1) + + timings = np.zeros((repetitions, 1)) + with torch.no_grad(): + for rep in range(repetitions): + starter.record() + func(data, dim=-1) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[rep] = curr_time + + mean_syn = np.sum(timings) / repetitions + return mean_syn + +def test(): + size = (17, 3, 1024, 256) + data = torch.randn(size = size, device="cuda", dtype=torch.float16) + mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) + + out_cuda = scaled_masked_softmax_forward(data, mask, 1) + + out_torch = F.softmax(data, dim = -1) + + torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-5, atol=1e-5) + + latency_1 = get_latency_for_cuda(scaled_masked_softmax_forward, data, mask, 1) + latency_2 = get_latency_for_torch(F.softmax, data) + print("the cuda implementation is {} ms".format(str(latency_1))) + print("the original torch cuda implementation is {} ms".format(str(latency_2))) + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/triton/test_self_attention.py similarity index 100% rename from tests/test_kernels/test_self_attention.py rename to tests/test_kernels/triton/test_self_attention.py diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/triton/test_softmax.py similarity index 100% rename from tests/test_kernels/test_softmax.py rename to tests/test_kernels/triton/test_softmax.py From e48a1431ce2fceecb7a64da7e9c32332584bce1e Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 14 Aug 2023 17:44:51 +0800 Subject: [PATCH 08/31] added tests --- .../layernorm/layernorm.cpp | 14 ++++ .../layernorm/layernorm_kernels.cu | 62 +++++++++++++++ .../layernorm/reduction_utils.cuh | 50 ++++++++++++ .../rotary_embedding/pos_encoding.cpp | 3 + .../rotary_embedding/pos_encoding_kernels.cu | 3 + colossalai/kernel/cuda_native/setup.py | 16 ++++ tests/test_kernels/cuda/test_rmsnorm.py | 57 ++++++++++++++ tests/test_kernels/cuda/test_softmax.py | 78 ++++--------------- 8 files changed, 222 insertions(+), 61 deletions(-) create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu create mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh create mode 100644 tests/test_kernels/cuda/test_rmsnorm.py diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp new file mode 100644 index 000000000000..749ca5f92154 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp @@ -0,0 +1,14 @@ +#include + +void rms_norm( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& weight, + float epsilon); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "rms_norm", + &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); +} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu new file mode 100644 index 000000000000..7f7889df3ec4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu @@ -0,0 +1,62 @@ +/*This code from Vllm : https://github.com/vllm-project/vllm + * with minor changes. */ + +#include +#include + +#include "reduction_utils.cuh" + +template +__global__ void rms_norm_kernel( + scalar_t* __restrict__ out, // [num_tokens, hidden_size] + const scalar_t* __restrict__ input, // [num_tokens, hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float) input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float) input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + } +} + + +void rms_norm( + torch::Tensor& out, // [num_tokens, hidden_size] + torch::Tensor& input, // [num_tokens, hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int num_tokens = input.size(0); + int hidden_size = input.size(1); + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "rms_norm_kernel", + [&] { + rms_norm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size); + }); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh new file mode 100644 index 000000000000..2d47c5222084 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh @@ -0,0 +1,50 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask, 32); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp index 565d134cdedf..16749fd52155 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -1,3 +1,6 @@ +/*This code from Vllm : https://github.com/vllm-project/vllm + * with minor changes. */ + #include void rotary_embedding_neox( diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu index 1f0f8968619b..8324f1e66556 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -1,3 +1,6 @@ +/*This code from Vllm : https://github.com/vllm-project/vllm + * with minor changes. */ + #include #include diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py index ba6f4a6fdbd1..d717c5bb0a69 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/colossalai/kernel/cuda_native/setup.py @@ -143,6 +143,22 @@ def append_nvcc_threads(nvcc_extra_args): 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) }, ), + + CUDAExtension( + name="col_rms_norm_ops", + sources=[ + "csrc/attention_infer_kernels/layernorm/layernorm.cpp", + "csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu" + ], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) + }, + include_dirs=[ + Path(this_dir)/'csrc'/'attention_infer_kernels'/'layernorm', + ], + ), + CUDAExtension( name="col_flash_attn_2_lib", sources=[ diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py new file mode 100644 index 000000000000..8f748dedefee --- /dev/null +++ b/tests/test_kernels/cuda/test_rmsnorm.py @@ -0,0 +1,57 @@ +import os +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +try: + from col_rms_norm_ops import rms_norm + HAS_INFER_CUDA = True +except: + HAS_INFER_CUDA = False + print("please install your cuda ") + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + + +def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + + +def test_rmsnorm(): + data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") + hg_rms = LlamaRMSNorm(64) + hg_rms = hg_rms.half().cuda() + out_torch = hg_rms(data) + out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) + + check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + +if __name__ == "__main__": + if HAS_INFER_CUDA: + test_rmsnorm() \ No newline at end of file diff --git a/tests/test_kernels/cuda/test_softmax.py b/tests/test_kernels/cuda/test_softmax.py index 7881c9d3811c..879e8981c85f 100644 --- a/tests/test_kernels/cuda/test_softmax.py +++ b/tests/test_kernels/cuda/test_softmax.py @@ -3,69 +3,25 @@ import torch from torch.nn import functional as F -from col_fused_softmax_lib import scaled_masked_softmax_forward +try: + from col_fused_softmax_lib import scaled_masked_softmax_forward + HAS_INFER_CUDA = True +except: + HAS_INFER_CUDA = False + print("please install your cuda ") -def get_latency_for_cuda(func, data, mask, scale): - starter, ender = torch.cuda.Event( - enable_timing=True), torch.cuda.Event(enable_timing=True) - repetitions = 300 +if HAS_INFER_CUDA: + def test(): + size = (17, 3, 1024, 256) + data = torch.randn(size = size, device="cuda", dtype=torch.float16) + mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) - for i in range(10): - func(data, mask, scale) - - timings = np.zeros((repetitions, 1)) - with torch.no_grad(): - for rep in range(repetitions): - starter.record() - func(data, mask, 1) - ender.record() - # WAIT FOR GPU SYNC - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings[rep] = curr_time + out_cuda = scaled_masked_softmax_forward(data, mask, 1) - mean_syn = np.sum(timings) / repetitions - return mean_syn + out_torch = F.softmax(data, dim = -1) + check = torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "the output from cuda softmax is not matched with output from torch" -def get_latency_for_torch(func, data): - starter, ender = torch.cuda.Event( - enable_timing=True), torch.cuda.Event(enable_timing=True) - repetitions = 300 - - for i in range(10): - func(data, dim=-1) - - timings = np.zeros((repetitions, 1)) - with torch.no_grad(): - for rep in range(repetitions): - starter.record() - func(data, dim=-1) - ender.record() - # WAIT FOR GPU SYNC - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings[rep] = curr_time - - mean_syn = np.sum(timings) / repetitions - return mean_syn - -def test(): - size = (17, 3, 1024, 256) - data = torch.randn(size = size, device="cuda", dtype=torch.float16) - mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) - - out_cuda = scaled_masked_softmax_forward(data, mask, 1) - - out_torch = F.softmax(data, dim = -1) - - torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-5, atol=1e-5) - - latency_1 = get_latency_for_cuda(scaled_masked_softmax_forward, data, mask, 1) - latency_2 = get_latency_for_torch(F.softmax, data) - print("the cuda implementation is {} ms".format(str(latency_1))) - print("the original torch cuda implementation is {} ms".format(str(latency_2))) - - -if __name__ == "__main__": - test() \ No newline at end of file + if __name__ == "__main__": + test() \ No newline at end of file From 49bb1e3a18b26a2e2264d994a0503c2407ff0fa6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 14 Aug 2023 19:02:17 +0800 Subject: [PATCH 09/31] added tests --- .../cuda/test_rotary_embedding.py | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 tests/test_kernels/cuda/test_rotary_embedding.py diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py new file mode 100644 index 000000000000..178857eae4ad --- /dev/null +++ b/tests/test_kernels/cuda/test_rotary_embedding.py @@ -0,0 +1,151 @@ +from typing import Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half + +try: + from col_pos_encoding_ops import rotary_embedding_neox + HAS_INFER_CUDA = True +except: + HAS_INFER_CUDA = False + print("the cuda infer kernels for llama attention is not installed") + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbeddingNeox(nn.Module): + """Reference implementation of the GPT-NeoX style rotary embedding.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + ) -> None: + super().__init__() + self.rotary_dim = dim + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. + inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=inv_freq.dtype) + sin = emb.sin().to(dtype=inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.Tensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + ) -> Tuple[torch.Tensor, torch.Tensor]: + + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + + query_rot = query_rot.transpose(0, 1) + key_rot = key_rot.transpose(0, 1) + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot = query_rot.transpose(0, 1).contiguous() + key_rot = key_rot.transpose(0, 1).contiguous() + + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + + # Output query/key shape: [num_tokens, num_tokens, head_size] + return query, key + +def run_rotary_embedding_neox( + num_tokens: int, + num_heads: int, + head_size: int, + max_position: int, + rotary_dim: int, + dtype: torch.dtype, + base: int = 10000, +) -> None: + positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') + query = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + key = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + + # Create the rotary embedding. + inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() + rotary_embedding_neox( + positions, + out_query, + out_key, + head_size, + cos_sin_cache, + ) + + # Run the reference implementation. + ref_rotary_embedding = RefRotaryEmbeddingNeox( + dim=rotary_dim, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device='cuda') + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + + # Compare the results. + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + + +def test(): + run_rotary_embedding_neox( + num_tokens=1024, + num_heads=8, + head_size=64, + max_position=8192, + rotary_dim=64, + dtype=torch.float16, + ) + +if __name__ == "__main__": + if HAS_INFER_CUDA: + test() \ No newline at end of file From 0b3cffaa744736acc7fa7497e04008fc22f5a6c0 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 14 Aug 2023 20:40:45 +0800 Subject: [PATCH 10/31] add comments --- .../csrc/attention_infer_kernels/flash_attn/flash_api.cpp | 2 +- .../csrc/attention_infer_kernels/flash_attn/src/block_info.h | 2 +- .../csrc/attention_infer_kernels/flash_attn/src/flash.h | 2 +- .../flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu | 3 +++ .../flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu | 4 +++- .../flash_attn/src/flash_fwd_kernel.h | 2 +- .../flash_attn/src/flash_fwd_launch_template.h | 2 +- .../attention_infer_kernels/flash_attn/src/kernel_traits.h | 2 +- .../flash_attn/src/kernel_traits_sm90.h | 2 +- .../csrc/attention_infer_kernels/flash_attn/src/philox.cuh | 5 ++++- .../csrc/attention_infer_kernels/flash_attn/src/softmax.h | 2 +- .../csrc/attention_infer_kernels/flash_attn/src/utils.h | 2 +- 26 files changed, 61 insertions(+), 25 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp index 07252a3c85bf..9e2847cf5fcc 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h index 94251a41e43b..f6597f1c08da 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h index e65d7d536aa9..17c141b051b6 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu index 2c8f75b17973..37c2dc95d3ea 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu index eca6a06632bd..a855d70e09b7 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu index 898cd9c4b6dd..78dc9fa61e1c 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu index 19adb4b28deb..105538394349 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu index 130bf71d0c5d..bf5ca1d56c68 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu index 32cff41a55f0..5862b878c98e 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu index 982fe7eadecc..2bffeb1fdfe2 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu index 4c083f7b663f..bba585ebba98 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu index cb074a95ed8c..3432632cf312 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu index ddf5e132293d..f8ddf7d761e8 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu index 81e359e16feb..d4800c2932a2 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu index b20a1781560e..c997b80c7b68 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -1,5 +1,8 @@ // Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. #include "flash_fwd_launch_template.h" diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu index 12b4552c2073..9177a8f5eebb 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu index dd20bc67282b..ab1427e3efd3 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu index 7039334a3ae9..399d8e21b85f 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu index a8420bd02945..9da8085acbf3 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2023, Tri Dao. +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ // Splitting the different head dimensions to different files to speed up compilation. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h index 7539d71dfc50..6eeca8893e76 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h index 1f205961f7b3..14b59b9182b5 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h index 3468e4bffc37..4cf445e38545 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h index e07f383904a8..c8b61ae2d741 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh index 6ce1440f288d..c54ea0cd14dd 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh @@ -1,4 +1,7 @@ -// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +/****************************************************************************** + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention + ******************************************************************************/ + #pragma once // Philox CUDA. diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h index 3e9a7b4597c6..76200be8c774 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h index 2221a2faf3a8..a75ae2ca25e8 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention ******************************************************************************/ #pragma once From c171f437172380e301aa4c603cf1d529d221160a Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 17:02:40 +0800 Subject: [PATCH 11/31] refactoring --- .../attention_infer_kernels/linear/gemm.cu | 40 ---- .../linear/linear_op.cpp | 206 ------------------ .../{layernorm => rmsnorm}/layernorm.cpp | 0 .../layernorm_kernels.cu | 0 .../reduction_utils.cuh | 0 colossalai/kernel/cuda_native/setup.py | 39 +--- 6 files changed, 3 insertions(+), 282 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp rename colossalai/kernel/cuda_native/csrc/attention_infer_kernels/{layernorm => rmsnorm}/layernorm.cpp (100%) rename colossalai/kernel/cuda_native/csrc/attention_infer_kernels/{layernorm => rmsnorm}/layernorm_kernels.cu (100%) rename colossalai/kernel/cuda_native/csrc/attention_infer_kernels/{layernorm => rmsnorm}/reduction_utils.cuh (100%) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu deleted file mode 100644 index a0c42bdb05fb..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/gemm.cu +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include -#include - -#include -#include - -void dense_layer_fp32_kernel(const float *in, const float *weight, float *out, const int M, - const int K, const int N, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo) { - const float alpha = 1.0f, beta = 0.0f; - cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, weight, - CUDA_R_32F, N, in, CUDA_R_32F, K, &beta, out, CUDA_R_32F, N, - CUDA_R_32F, static_cast(cublasAlgo)); -} - -void dense_layer_fp16_kernel(const __half *in, const __half *weight, __half *out, const int M, - const int K, const int N, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo) { - const __half alpha = (__half)1.0f, beta = (__half)0.0f; - cublasGemmEx(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, weight, - CUDA_R_16F, N, in, CUDA_R_16F, K, &beta, out, CUDA_R_16F, N, - CUDA_R_16F, static_cast(cublasAlgo)); -} - - -void cublas_Gemm_Strided_Batched_FP16_Kernel(const __half *A, const __half *B, __half *out, const int M, - const int K, const int N, const int batch_count, - cublasOperation_t trans_A, cublasOperation_t trans_B, - __half alpha, __half beta, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo) { - const int lda = (trans_A == CUBLAS_OP_N) ? K : M; - const int ldb = (trans_B == CUBLAS_OP_N) ? N : K; - - - cublasGemmStridedBatchedEx( - cublas_handle, trans_B, trans_A, N, M, K, &alpha, B, CUDA_R_16F, ldb, K * N, A, CUDA_R_16F, - lda, M * K, &beta, out, CUDA_R_16F, N, M * N, batch_count, CUDA_R_16F, - static_cast(cublasAlgo)); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp deleted file mode 100644 index dc4f50ebfa55..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/linear/linear_op.cpp +++ /dev/null @@ -1,206 +0,0 @@ -#include -#include -#include -#include -#include - - -class CublasHandle -{ -public: - static CublasHandle& instance() - { - static CublasHandle handle; - return handle; - } - - cublasHandle_t get() const - { - return handle; - } - - CublasHandle(CublasHandle const&) = delete; - void operator=(CublasHandle const&) = delete; - -private: - cublasHandle_t handle; - - CublasHandle() - { - cublasStatus_t stat = cublasCreate(&handle); - if (stat != CUBLAS_STATUS_SUCCESS) - { - printf("cuBLAS initialization error: %d\n", stat); - exit(stat); - } - } - - ~CublasHandle() - { - cublasDestroy(handle); - } -}; - -class CudaStream { -public: - // Get the singleton instance - static CudaStream& instance() { - static CudaStream instance; - return instance; - } - - // Get the cudaStream_t - cudaStream_t get() const { - return stream; - } - -private: - // The cudaStream_t object - cudaStream_t stream; - - // Private constructor and destructor - CudaStream() { - cudaError_t err = cudaStreamCreate(&stream); - if (err != cudaSuccess) { - printf("cuda stream initialization error"); - exit(-1); - } - } - - ~CudaStream() { - cudaStreamDestroy(stream); - } - - // Delete copy and assignment constructors - CudaStream(const CudaStream&) = delete; - CudaStream& operator=(const CudaStream&) = delete; -}; - - - - -void dense_layer_fp32_kernel(const float *in, const float *weight, float *out, const int M, - const int K, const int N, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo = -1); - -void dense_layer_fp16_kernel(const __half *in, const __half *weight, __half *out, const int M, - const int K, const int N, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo = 99); - - -void cublas_Gemm_Strided_Batched_FP16_Kernel(const __half *A, const __half *B, __half *out, const int M, - const int K, const int N, const int batch_count, - cublasOperation_t trans_A, cublasOperation_t trans_B, - __half alpha, __half beta, cublasHandle_t cublas_handle, - cudaStream_t stream, int cublasAlgo = 99); - - -void dense_layer_fp32_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, int cublasAlgo) { - const int M = in.size(0); - const int K = in.size(1); - const int N = weight.size(1); - // Assumes in and weight are CUDA tensors, hence can call .data_ptr. - - cublasHandle_t handle = CublasHandle::instance().get(); - - // Now you can get a cudaStream_t like this: - cudaStream_t stream = CudaStream::instance().get(); - - dense_layer_fp32_kernel(in.data_ptr(), weight.data_ptr(), out.data_ptr(), M, K, N, handle, stream, cublasAlgo); - -} - - - -void dense_layer_fp16_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, int cublasAlgo = 99) { - const int M = in.size(0); - const int K = in.size(1); - const int N = weight.size(1); - - cublasHandle_t handle = CublasHandle::instance().get(); - - // Now you can get a cudaStream_t like this: - cudaStream_t stream = CudaStream::instance().get(); - - if(in.is_contiguous() == false){ - in = in.contiguous(); - } - - if(weight.is_contiguous() == false) { - weight = weight.contiguous(); - } - - if(out.is_contiguous() == false) { - out = out.contiguous(); - } - - dense_layer_fp16_kernel(reinterpret_cast(in.data_ptr()), - reinterpret_cast(weight.data_ptr()), - reinterpret_cast<__half*>(out.data_ptr()), - M, K, N, handle, stream, cublasAlgo); - - - -} - -void batch_dense_layer_fp16_forward(torch::Tensor& in, torch::Tensor& weight, torch::Tensor& out, float alpha, float beta, bool weight_transpose = false, int cublasAlgo = 99) { - const int batch_count = in.size(0); - const int M = in.size(1); - const int K = in.size(2); - int N = weight.size(2); - if(weight_transpose) { - N = weight.size(1); - } - - cublasHandle_t handle = CublasHandle::instance().get(); - - // Now you can get a cudaStream_t like this: - cudaStream_t stream = CudaStream::instance().get(); - - // if(in.is_contiguous() == false){ - // in = in.contiguous(); - // } - - // if(weight.is_contiguous() == false) { - // weight = weight.contiguous(); - // } - - // if(out.is_contiguous() == false) { - // out = out.contiguous(); - // } - if(weight_transpose == false) { - cublas_Gemm_Strided_Batched_FP16_Kernel(reinterpret_cast(in.data_ptr()), - reinterpret_cast(weight.data_ptr()), - reinterpret_cast<__half*>(out.data_ptr()), - M, K, N, batch_count, - CUBLAS_OP_N, CUBLAS_OP_N, - (__half)alpha, (__half)beta, handle, stream, cublasAlgo - ); - }else { - cublas_Gemm_Strided_Batched_FP16_Kernel(reinterpret_cast(in.data_ptr()), - reinterpret_cast(weight.data_ptr()), - reinterpret_cast<__half*>(out.data_ptr()), - M, K, N, batch_count, - CUBLAS_OP_N, CUBLAS_OP_T, - (__half)alpha, (__half)beta, handle, stream, cublasAlgo - ); - } -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dense_layer_fp32_forward", - &dense_layer_fp32_forward, - "fp32 forward of dense layer"); - - m.def("dense_layer_fp16_forward", - &dense_layer_fp16_forward, - "fp16 forward of dense layer." - ); - - m.def("batch_dense_layer_fp16_forward", - &batch_dense_layer_fp16_forward, - "fp16 forward of batch gemm" - ); - -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm.cpp rename to colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu rename to colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/reduction_utils.cuh similarity index 100% rename from colossalai/kernel/cuda_native/csrc/attention_infer_kernels/layernorm/reduction_utils.cuh rename to colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/reduction_utils.cuh diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py index d717c5bb0a69..3665327769bd 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/colossalai/kernel/cuda_native/setup.py @@ -118,7 +118,7 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("arch=compute_90,code=sm_90") setup( - name='colossal-cuda-kernels', + name='colossal-cuda-infer-kernels', ext_modules=[ CUDAExtension( name='col_fused_softmax_lib', @@ -147,8 +147,8 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_rms_norm_ops", sources=[ - "csrc/attention_infer_kernels/layernorm/layernorm.cpp", - "csrc/attention_infer_kernels/layernorm/layernorm_kernels.cu" + "csrc/attention_infer_kernels/rmsnorm/layernorm.cpp", + "csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu" ], extra_compile_args={ 'cxx': ['-O3',], @@ -207,39 +207,6 @@ def append_nvcc_threads(nvcc_extra_args): ], ), - CUDAExtension( - name="col_linear_lib", - sources=[ - "csrc/attention_infer_kernels/linear/linear_op.cpp", - "csrc/attention_infer_kernels/linear/gemm.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo" - "-lcublas" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'csrc'/'attention_infer_kernels' /'linear' , - Path(this_dir) / 'csrc'/'cutlass' / 'include', - ], - ), - ], cmdclass={ 'build_ext': BuildExtension From 0e0594e30aa7255c6fb3f5fe809999564ae6f681 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 17:22:59 +0800 Subject: [PATCH 12/31] fix tests --- colossalai/kernel/cuda_native/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py index 3665327769bd..1be6690b46ac 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/colossalai/kernel/cuda_native/setup.py @@ -155,7 +155,7 @@ def append_nvcc_threads(nvcc_extra_args): 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) }, include_dirs=[ - Path(this_dir)/'csrc'/'attention_infer_kernels'/'layernorm', + Path(this_dir)/'csrc'/'attention_infer_kernels'/'rmsnorm', ], ), From 9393dd18d870a6b2208f89884e2787f1a352d213 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 17:47:59 +0800 Subject: [PATCH 13/31] change flash-attention as thrid-party directly --- .../flash_attn/flash_api.cpp | 420 ------------- .../flash_attn/src/block_info.h | 41 -- .../flash_attn/src/flash.h | 144 ----- .../src/flash_fwd_hdim128_bf16_sm80.cu | 12 - .../src/flash_fwd_hdim128_fp16_sm80.cu | 13 - .../src/flash_fwd_hdim160_bf16_sm80.cu | 12 - .../src/flash_fwd_hdim160_fp16_sm80.cu | 12 - .../src/flash_fwd_hdim192_bf16_sm80.cu | 11 - .../src/flash_fwd_hdim192_fp16_sm80.cu | 12 - .../src/flash_fwd_hdim224_bf16_sm80.cu | 11 - .../src/flash_fwd_hdim224_fp16_sm80.cu | 11 - .../src/flash_fwd_hdim256_bf16_sm80.cu | 11 - .../src/flash_fwd_hdim256_fp16_sm80.cu | 11 - .../src/flash_fwd_hdim32_bf16_sm80.cu | 12 - .../src/flash_fwd_hdim32_fp16_sm80.cu | 13 - .../src/flash_fwd_hdim64_bf16_sm80.cu | 12 - .../src/flash_fwd_hdim64_fp16_sm80.cu | 12 - .../src/flash_fwd_hdim96_bf16_sm80.cu | 12 - .../src/flash_fwd_hdim96_fp16_sm80.cu | 11 - .../flash_attn/src/flash_fwd_kernel.h | 572 ------------------ .../src/flash_fwd_launch_template.h | 227 ------- .../flash_attn/src/kernel_traits.h | 366 ----------- .../flash_attn/src/kernel_traits_sm90.h | 159 ----- .../flash_attn/src/philox.cuh | 168 ----- .../flash_attn/src/softmax.h | 272 --------- .../flash_attn/src/static_switch.h | 66 -- .../flash_attn/src/utils.h | 388 ------------ colossalai/kernel/cuda_native/setup.py | 48 -- 28 files changed, 3059 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp deleted file mode 100644 index 9e2847cf5fcc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/flash_api.cpp +++ /dev/null @@ -1,420 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#include -#include -#include - -#include - -#include "flash.h" -#include "static_switch.h" - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - - -void set_params_fprop(Flash_fwd_params ¶ms, - // sizes - const size_t b, - const size_t seqlen_q, - const size_t seqlen_k, - const size_t seqlen_q_rounded, - const size_t seqlen_k_rounded, - const size_t h, - const size_t h_k, - const size_t d, - const size_t d_rounded, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - at::Tensor out, - void *cu_seqlens_q_d, - void *cu_seqlens_k_d, - void *p_d, - void *softmax_lse_d, - float p_dropout, - float softmax_scale, - bool is_causal) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - if (cu_seqlens_q_d == nullptr) { - params.q_batch_stride = q.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - params.o_batch_stride = out.stride(0); - } - - params.cu_seqlens_q = static_cast(cu_seqlens_q_d); - params.cu_seqlens_k = static_cast(cu_seqlens_k_d); - - // P = softmax(QK^T) - params.p_ptr = p_d; - - // Softmax sum - params.softmax_lse_ptr = softmax_lse_d; - - // Set the dimensions. - params.b = b; - params.h = h; - params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; - params.d_rounded = d_rounded; - - // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. - // [Minor] We want to round down since when we do the comparison we use <= instead of < - // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - TORCH_CHECK(p_dropout < 1.f); - - params.is_causal = is_causal; -} - -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); - }); - }); -} - -std::vector -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - const float p_dropout, - const float softmax_scale, - const bool is_causal, - const bool return_softmax, - c10::optional gen_) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - if (q_dtype == torch::kBFloat16) { - TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); - } - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; - const int num_heads = sizes[2]; - const int head_size_og = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - - at::Tensor q_padded, k_padded, v_padded; - if (head_size_og % 8 != 0) { - q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - } else { - q_padded = q; - k_padded = k; - v_padded = v; - } - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } - } else { - out = torch::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time - if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - - Flash_fwd_params params; - set_params_fprop(params, - batch_size, - seqlen_q, seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q_padded, k_padded, v_padded, out, - /*cu_seqlens_q_d=*/nullptr, - /*cu_seqlens_k_d=*/nullptr, - return_softmax ? p.data_ptr() : nullptr, - softmax_lse.data_ptr(), - p_dropout, - softmax_scale, - is_causal); - - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - - at::Tensor out_padded = out; - if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - if (out_.has_value()) { out_.value().copy_(out); } - } - - return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; -} - -std::vector -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q, - const int max_seqlen_k, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const bool return_softmax, - c10::optional gen_) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - if (q_dtype == torch::kBFloat16) { - TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); - } - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); - TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); - TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); - - const auto sizes = q.sizes(); - - const int total_q = sizes[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = sizes[1]; - const int head_size_og = sizes[2]; - const int total_k = k.size(0); - const int num_heads_k = k.size(1); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - CHECK_SHAPE(q, total_q, num_heads, head_size_og); - CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - at::Tensor q_padded, k_padded, v_padded; - if (head_size_og % 8 != 0) { - q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - } else { - q_padded = q; - k_padded = k; - v_padded = v; - } - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, total_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } - } else { - out = torch::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - - auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time - if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - - if (zero_tensors) { - out.zero_(); - softmax_lse.fill_(-std::numeric_limits::infinity()); - if (return_softmax) {p.zero_();} - } - - Flash_fwd_params params; - set_params_fprop(params, - batch_size, - max_seqlen_q, max_seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q_padded, k_padded, v_padded, out, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - return_softmax ? p.data_ptr() : nullptr, - softmax_lse.data_ptr(), - p_dropout, - softmax_scale, - is_causal); - - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); - - at::Tensor out_padded = out; - if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - if (out_.has_value()) { out_.value().copy_(out); } - } - - return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; -} - - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashAttention"; - m.def("flash_fwd", &mha_fwd, "Forward pass"); - m.def("varlen_flash_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h deleted file mode 100644 index f6597f1c08da..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/block_info.h +++ /dev/null @@ -1,41 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockInfo { - - template - __device__ BlockInfo(const Params ¶ms, const int bidb) - : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) - , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) - { - } - - template - inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; - } - - template - inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; - } - - const int sum_s_q; - const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h deleted file mode 100644 index 17c141b051b6..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash.h +++ /dev/null @@ -1,144 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include - - -constexpr int TOTAL_DIM = 0; -constexpr int H_DIM = 1; -constexpr int D_DIM = 2; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Qkv_params { - using index_t = uint32_t; - // The QKV matrices. - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; - - // The number of heads. - int h, h_k; - // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be - // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_fwd_params : public Qkv_params { - - // The O matrix (output). - void * __restrict__ o_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the P matrix. - void * __restrict__ p_ptr; - - // The pointer to the softmax sum. - void * __restrict__ softmax_lse_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; - - // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens_q; - int * __restrict__ cu_seqlens_k; - - int *__restrict__ blockmask; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - float scale_softmax_rp_dropout; - - // Random state. - at::PhiloxCudaState philox_args; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; - - bool is_bf16; - bool is_causal; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_bwd_params : public Flash_fwd_params { - - // The dO and dQKV matrices. - void *__restrict__ do_ptr; - void *__restrict__ dq_ptr; - void *__restrict__ dk_ptr; - void *__restrict__ dv_ptr; - - // To accumulate dQ - void *__restrict__ dq_accum_ptr; - void *__restrict__ dk_accum_ptr; - void *__restrict__ dv_accum_ptr; - - // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q - // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ - // dv_accum_ptr; - - // The stride between rows of the dO, dQ, dK and dV matrices. - // TD [2022-04-16]: We're using 32-bit indexing to save registers. - // The code probably won't work for arrays larger than 2GB. - index_t do_batch_stride; - index_t do_row_stride; - index_t do_head_stride; - index_t dq_batch_stride; - index_t dk_batch_stride; - index_t dv_batch_stride; - index_t dq_row_stride; - index_t dk_row_stride; - index_t dv_row_stride; - index_t dq_head_stride; - index_t dk_head_stride; - index_t dv_head_stride; - - // The pointer to the softmax d sum. - void *__restrict__ dsoftmax_sum; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); - -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu deleted file mode 100644 index 37c2dc95d3ea..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu deleted file mode 100644 index a855d70e09b7..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ /dev/null @@ -1,13 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 78dc9fa61e1c..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu deleted file mode 100644 index 105538394349..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu deleted file mode 100644 index bf5ca1d56c68..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu deleted file mode 100644 index 5862b878c98e..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu deleted file mode 100644 index 2bffeb1fdfe2..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu deleted file mode 100644 index bba585ebba98..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu deleted file mode 100644 index 3432632cf312..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu deleted file mode 100644 index f8ddf7d761e8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu deleted file mode 100644 index d4800c2932a2..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu deleted file mode 100644 index c997b80c7b68..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) 2023, Tri Dao. - -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu deleted file mode 100644 index 9177a8f5eebb..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu deleted file mode 100644 index ab1427e3efd3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu deleted file mode 100644 index 399d8e21b85f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ /dev/null @@ -1,12 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu deleted file mode 100644 index 9da8085acbf3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -// Splitting the different head dimensions to different files to speed up compilation. - -#include "flash_fwd_launch_template.h" - -template<> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); -} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h deleted file mode 100644 index 6eeca8893e76..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_kernel.h +++ /dev/null @@ -1,572 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include - -#include "block_info.h" -#include "kernel_traits.h" -#include "utils.h" -#include "softmax.h" -#include "philox.cuh" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - // TODO: Shouldn't this be size<1>? - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, - Tensor2 &acc_o, float softmax_scale_log2) { - if (Is_first) { - flash::template reduce_max(scores, scores_max); - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - flash::reduce_sum(scores, scores_sum); - } else { - Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); - flash::template reduce_max(scores, scores_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - #pragma unroll - for (int mi = 0; mi < size(scores_max); ++mi) { - float scores_max_cur = !Check_inf - ? scores_max(mi) - : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scores_sum(mi) *= scores_scale; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } - } - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - Tensor scores_sum_cur = make_fragment_like(scores_sum); - flash::reduce_sum(scores, scores_sum_cur); - #pragma unroll - for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void write_softmax_to_gmem( - Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_thr_copy_P -) { - // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) - Layout l = tOrP.layout(); - Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); - CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); - CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); - #pragma unroll - for (int mi = 0; mi < size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - // The global block index. - const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; - - int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { - // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); - // } - } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) - + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded - + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), - Shape, Int>{}, - make_stride(params.seqlen_k_rounded, _1{})); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; - Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); - auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tPgP = gmem_thr_copy_P.partition_D(gP); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - // if (cute::thread0()) {smem_thr_copy_Q.print_all();} - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} - - auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - - // - // PREDICATES - // - - // // Allocate predicate tensors for m and n - // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); - // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); - - // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - - // Set predicates for k bounds - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } - } - - // Prologue - - Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } - - // // Copy rmem to smem - // // copy(tQrQ, tQsQ); - // flash::cp_async_wait<0>(); - // __syncthreads(); - // // if (cute::thread(1, 0)) { print(tQsQ); } - // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); - // // if (cute::thread0()) { print(sQNoSwizzle); } - - if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); - __syncthreads(); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); - __syncthreads(); - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } - // __syncthreads(); - - if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); - __syncthreads(); - Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M - copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); - } - - auto seeds = at::cuda::philox::unpack(params.philox_args); - unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; - - // Save seed and offset for backward. - if (block_id == 0 && tidx == 0) { - params.rng_state[0] = seed; - params.rng_state[1] = std::get<1>(seeds); - } - - clear(acc_o); - - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; - #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - - // Advance gV - if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - } - cute::cp_async_fence(); - - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K - ); - // if (cute::thread0()) { print(acc_s); } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal) { - if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } - } else { - // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) - // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) - // static_assert(decltype(size<0>(taccScS))::value == 4); - // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. - // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); - // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout())); - // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM); - // Idk why it's get<1> and not get<0> of the stride. - // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } - // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); - } - - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > 0) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - // TODO: when we have key_padding_mask we'll need to Check_inf - masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { - Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); - flash::apply_dropout( - tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps - ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - tPgP.data() = tPgP.data() + (-kBlockN); - } - if (Is_dropout) { - flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); - } - // if (cute::thread0()) { print(tOrP); } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } - - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { - --n_block; - break; - } - } - - // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K - ); - - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > 0) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { - Tensor tOrP_copy = make_fragment_like(tOrP); - copy(tOrP, tOrP_copy); - flash::apply_dropout( - tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps - ); - flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - tPgP.data() = tPgP.data() + (-kBlockN); - } - if (Is_dropout) { - flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); - } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); - } - - // Epilogue - - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - Tensor lse = make_fragment_like(scores_sum); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } - - // if (cute::thread0()) { print(acc_o_rowcol); } - - // Convert acc_o from fp32 to fp16/bf16 - Tensor rO = flash::convert_type(acc_o); - Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sO has the same size as sQ, so we don't need to sync here. - if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - - copy(smem_thr_copy_O, taccOrO, taccOsO); - - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - __syncthreads(); - - Tensor tOrO = make_tensor(shape(tOgO)); - copy(gmem_thr_copy_O, tOsO, tOrO); - - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. - Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } - } - } - - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_attn(const Params ¶ms) { - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h deleted file mode 100644 index 14b59b9182b5..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/flash_fwd_launch_template.h +++ /dev/null @@ -1,227 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include - -#include "static_switch.h" -#include "flash.h" -#include "flash_fwd_kernel.h" - -template -__global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); -} - -template -void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr size_t smem_size = Kernel_traits::kSmemSize; - // printf("smem_size = %d\n", smem_size); - - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. - const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); -} - -template -void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 32; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 64; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 256) is 27% slower for seqlen=2k - // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 96; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - - }); - }); -} - -template -void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 128; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 160; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 192; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); - }); -} - -template -void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 224; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} - -template -void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int Headdim = 256; - int device; - cudaGetDevice(&device); - int max_smem_per_sm, max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); - status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, we want to run with 128 x 64 (128KB smem). - // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // 64 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 96 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); - }); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h deleted file mode 100644 index 4cf445e38545..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits.h +++ /dev/null @@ -1,366 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include "cute/algorithm/copy.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include - -using namespace cute; - -template -struct Flash_kernel_traits { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using Element = elem_type; - static constexpr bool Has_cp_async = true; -#else - using Element = cutlass::half_t; - static constexpr bool Has_cp_async = false; -#endif - - using ElementAccum = float; - using index_t = uint32_t; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - using ValLayoutMNK = Layout>; -#else - using MMA_Atom_Arch = MMA_Atom; - using ValLayoutMNK = Layout>; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif -}; - -// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true -template > -struct Flash_fwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; - static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - - using TiledMma = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using SmemLayoutAtomQ = decltype( - composition(Swizzle{}, - // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); - - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store - -}; - -// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. -// No_double_buffer is another option to reduce smem usage, but will slow things down. -template > -struct Flash_bwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Is_V_in_regs = Is_V_in_regs_; - static constexpr bool No_double_buffer = No_double_buffer_; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - - static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; - static_assert(kNWarps % AtomLayoutMSdP == 0); - static_assert(kNWarps % AtomLayoutNdKV == 0); - static_assert(kNWarps % AtomLayoutMdQ == 0); - - using TiledMmaSdP = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using TiledMmadKV = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using TiledMmadQ = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using SmemLayoutAtomQdO = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutQdO = decltype(tile_to_shape( - SmemLayoutAtomQdO{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutAtomKV = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutKV = decltype(tile_to_shape( - // SmemLayoutAtomQdO{}, - SmemLayoutAtomKV{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutKtransposed = decltype(tile_to_shape( - SmemLayoutAtomKtransposed{}, - make_shape(Int{}, Int{}))); - // Maybe the KtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); - - // TODO: generalize to other values of kBlockN - // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 - // static constexpr int kPBlockN = kBlockN; - static_assert(kBlockN >= 64); - // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = 64; - static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); - // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); - static constexpr int kSwizzlePdS = 3; - using SmemLayoutAtomPdS = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutPdS = decltype(tile_to_shape( - SmemLayoutAtomPdS{}, - make_shape(Int{}, Int{}))); - using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutPdStransposed = decltype(tile_to_shape( - SmemLayoutAtomPdStransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); - using SmemCopyAtomPdS = Copy_Atom; - - using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutQdOtransposed = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); - - using SmemLayoutAtomdKV = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdKV = decltype(tile_to_shape( - SmemLayoutAtomdKV{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom; - - using SmemLayoutAtomdQ = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdQ = decltype(tile_to_shape( - SmemLayoutAtomdQ{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom; - - static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); - static constexpr int kSmemPCount = size(SmemLayoutPdS{}); - static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; - static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); - static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); - static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); - static constexpr int kSmemSize = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); - static constexpr int kSmemSize1colblock = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 - + kSmemdSSize + kSmemPSize; - - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem - // to affect speed in practice. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemLayoutAtomdQaccum = std::conditional_t< - kBlockKSmem == 32, - Layout, // Thread layout, 8 threads per row - Stride< _8, _1>>, - Layout, // Thread layout, 16 threads per row - Stride< _16, _1>> - >; - using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 4 vals per store - - using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom{}, - Layout, // Thread layout, 8 threads per row - Stride<_32, _1>>{}, - Layout>{})); // Val layout, 1 val per store - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h deleted file mode 100644 index c8b61ae2d741..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/kernel_traits_sm90.h +++ /dev/null @@ -1,159 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include "cute/algorithm/copy.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include - -using namespace cute; - -template -struct Flash_kernel_traits_sm90 { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using Element = elem_type; - static constexpr bool Has_cp_async = true; -#else - using Element = cutlass::half_t; - static constexpr bool Has_cp_async = false; -#endif - - using ElementAccum = float; - using index_t = uint32_t; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - using ValLayoutMNK = Layout>; -#else - using MMA_Atom_Arch = MMA_Atom; - using ValLayoutMNK = Layout>; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif -}; - -template > -struct Flash_fwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; - static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - - using TiledMma = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using SmemLayoutAtomQ = decltype( - composition(Swizzle{}, - // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); - - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh deleted file mode 100644 index c54ea0cd14dd..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/philox.cuh +++ /dev/null @@ -1,168 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once -// Philox CUDA. - -namespace flash { - -struct ull2 { - unsigned long long x; - unsigned long long y; -}; - -inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { - uint2 *res; - unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); - return *res; -} - -inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { - constexpr unsigned long kPhiloxSA = 0xD2511F53; - constexpr unsigned long kPhiloxSB = 0xCD9E8D57; - uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; - return ret; -} - -inline __device__ uint4 philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) { - constexpr unsigned long kPhilox10A = 0x9E3779B9; - constexpr unsigned long kPhilox10B = 0xBB67AE85; - uint2 key = reinterpret_cast(seed); - uint4 counter; - ull2 *tmp = reinterpret_cast(&counter); - tmp->x = offset; - tmp->y = subsequence; - #pragma unroll - for (int i = 0; i < 6; i++) { - counter = philox_single_round(counter, key); - key.x += (kPhilox10A); - key.y += (kPhilox10B); - } - uint4 output = philox_single_round(counter, key); - return output; -} - -} // namespace flash - -namespace { - -class Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) - : STATE(0) - , seed_(seed) - , offset_(offset) - , key(reinterpret_cast(seed)) { - //key.x = (unsigned int)seed; - //key.y = (unsigned int)(seed >> 32); - //counter = make_uint4(0, 0, 0, 0); - //counter.z = (unsigned int)(subsequence); - //counter.w = (unsigned int)(subsequence >> 32); - //STATE = 0; - //incr_n(offset / 4); - - // key = reinterpret_cast(seed); - ull2 * tmp = reinterpret_cast(&counter); - tmp->x = offset / 4; - tmp->y = subsequence; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); - // } - } - __device__ inline uint4 operator()() { - // // if (STATE == 0) { - // uint4 counter_ = counter; - // uint2 key_ = key; - // // 7-round philox - // #pragma unroll - // for (int i = 0; i < 6; i++) { - // counter_ = flash::philox_single_round(counter_, key_); - // key_.x += (kPhilox10A); - // key_.y += (kPhilox10B); - // } - // // output = philox_single_round(counter_, key_); - // uint4 output = flash::philox_single_round(counter_, key_); - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); - // // } - // incr(); - // // } - // // return a float4 directly - // // unsigned long ret; - // // switch(STATE) { - // // case 0: ret = output.x; break; - // // case 1: ret = output.y; break; - // // case 2: ret = output.z; break; - // // case 3: ret = output.w; break; - // //} - // // STATE = (STATE + 1) % 4; - // return output; - return flash::philox(seed_, offset_, offset_); - } - -private: - unsigned long long offset_, seed_; - struct ull2 { - uint64_t x; - uint64_t y; - }; - uint4 counter; - // uint4 output; - const uint2 key; - unsigned int STATE; - __device__ inline void incr_n(unsigned long long n) { - unsigned int nlo = (unsigned int)(n); - unsigned int nhi = (unsigned int)(n >> 32); - counter.x += nlo; - if (counter.x < nlo) - nhi++; - counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - - __device__ uint4 incr128 (uint4 ctr) - { - uint4 res; - asm ("add.cc.u32 %0, %4, %8;\n\t" - "addc.cc.u32 %1, %5, %9;\n\t" - "addc.cc.u32 %2, %6, %10;\n\t" - "addc.u32 %3, %7, %11;\n\t" - : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) - : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), - "n"(1), "n"(0), "n"(0), "n"(0)); - return res; - } - - __device__ inline void incr() { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - counter = incr128(counter); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - } - - static const unsigned long kPhilox10A = 0x9E3779B9; - static const unsigned long kPhilox10B = 0xBB67AE85; - // static const unsigned long kPhiloxSA = 0xD2511F53; - // static const unsigned long kPhiloxSB = 0xCD9E8D57; -}; - -} // namespace diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h deleted file mode 100644 index 76200be8c774..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/softmax.h +++ /dev/null @@ -1,272 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include - -#include - -#include -#include - -#include "philox.cuh" -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } - } -} - -template -__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } -} - -template -__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); -} - -template -__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); -} - -template -__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - reduce_(tensor, sum, sum_op); -} - -// Apply the exp to all the elements. -template -inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - } - } -} - -// Apply the exp to all the elements. -template -inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - -template -inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; - if (col_idx >= max_seqlen_k) { - // Without the "make_coord" we get wrong results - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; - } - } - } - } -} - -template -inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, - const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, - const uint32_t warp_row_stride) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; - const uint32_t row_idx_offset = row_idx_offset_; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const uint32_t row_idx = row_idx_base + i * 8; - const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } - } - // if (cute::thread0()) { - // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); - // print(tensor(make_coord(i, mi), _)); - // // print(tensor(_, j + nj * size<1, 0>(tensor))); - // } - } - } -} - -template -inline __device__ void apply_mask_causal_w_idx( - Tensor &tensor, Tensor const &idx_rowcol, - const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) -{ - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 2, "Only support 2D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); - CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); - #pragma unroll - for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { - if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; - } - } - // if (cute::thread0()) { - // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); - // print(tensor(_, make_coord(j, ni))); - // // print(tensor(_, j + ni * size<1, 0>(tensor))); - // } - } -} - -template -inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, - unsigned long long seed, unsigned long long offset, - uint32_t block_row_start, uint32_t block_col_start, - uint32_t block_row_stride) { - // tensor has shape (8, MMA_M, MMA_N / 2) - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); - }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. - if (!encode_dropout_in_sign_bit - && (std::is_same::value || std::is_same::value)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } -} - -} // namespace flash diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h deleted file mode 100644 index 4aa847402886..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/static_switch.h +++ /dev/null @@ -1,66 +0,0 @@ -// Inspired by -// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr static bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() - -#define FP16_SWITCH(COND, ...) \ - [&] { \ - if (COND) { \ - using elem_type = cutlass::half_t; \ - return __VA_ARGS__(); \ - } else { \ - using elem_type = cutlass::bfloat16_t; \ - return __VA_ARGS__(); \ - } \ - }() - -#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ - [&] { \ - if (HEADDIM <= 32) { \ - constexpr static int kHeadDim = 32; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 64) { \ - constexpr static int kHeadDim = 64; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 96) { \ - constexpr static int kHeadDim = 96; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 128) { \ - constexpr static int kHeadDim = 128; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 160) { \ - constexpr static int kHeadDim = 160; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 192) { \ - constexpr static int kHeadDim = 192; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 224) { \ - constexpr static int kHeadDim = 224; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 256) { \ - constexpr static int kHeadDim = 256; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h deleted file mode 100644 index a75ae2ca25e8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/flash_attn/src/utils.h +++ /dev/null @@ -1,388 +0,0 @@ -/****************************************************************************** - * The following codes are modified from the original FlashAttn library: https://github.com/Dao-AILab/flash-attention - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -#include - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#include -#endif - -#include -#include - -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace flash { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ uint32_t relu2(const uint32_t x); - -template<> -inline __device__ uint32_t relu2(const uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); -#else - asm volatile( \ - "{\n" \ - "\t .reg .f16x2 sela;\n" \ - "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" - "}\n" : "=r"(res) : "r"(x), "r"(zero)); -#endif - return res; -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> -inline __device__ uint32_t relu2(const uint32_t x) { - uint32_t res; - const uint32_t zero = 0u; - asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); - return res; -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -template -inline __device__ uint32_t convert_relu2(const float2 x); - -template<> -inline __device__ uint32_t convert_relu2(const float2 x) { - uint32_t res; - const uint32_t a = reinterpret_cast(x.x); - const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); - return res; -} - -template<> -inline __device__ uint32_t convert_relu2(const float2 x) { - uint32_t res; - const uint32_t a = reinterpret_cast(x.x); - const uint32_t b = reinterpret_cast(x.y); - asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); - return res; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2 (&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. -template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } -}; - -template <> -struct MaxOp { -// This is slightly faster -__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, - Tensor4 const& tCsB, TiledMma tiled_mma, - TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } - if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); - #pragma unroll - for (int i = 0; i < size<2>(tCrA); ++i) { - if (i < size<2>(tCrA) - 1) { - copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); - } - cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) -template -inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. -template -inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { - using X = Underscore; - static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); - static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); - constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); - constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; - auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ auto convert_type(Tensor const &tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op(*reinterpret_cast *>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void relu_(Tensor &tensor) { - constexpr int numel = decltype(size(tensor))::value; - static_assert(numel % 2 == 0); - using value_t = typename Engine::value_type; - // HACK: this requires tensor to be "contiguous" - Tensor tensor_uint32 = recast(tensor); - #pragma unroll - for (int i = 0; i < size(tensor_uint32); ++i) { - tensor_uint32(i) = relu2(tensor_uint32(i)); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction -template -inline __device__ auto convert_type_relu(Tensor const &tensor) { - using From_type = typename Engine::value_type; - static_assert(std::is_same_v || std::is_same_v); - static_assert(std::is_same_v); - constexpr int numel = decltype(size(tensor))::value; - static_assert(numel % 2 == 0); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - // HACK: this requires tensor to be "contiguous" - Tensor tensor_float2 = recast(tensor); - Tensor out_uint32 = make_tensor(tensor_float2.layout()); - #pragma unroll - for (int i = 0; i < size(out_uint32); ++i) { - out_uint32(i) = convert_relu2(tensor_float2(i)); - } - Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); -#else - Tensor out = flash::convert_type(tensor); - flash::relu_(out); -#endif - return out; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Blocks until all but N previous cp.async.commit_group operations have committed. -// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all -// (which is equivalent to commit_group then wait_group 0). -// Instead we just call cp.async.wait_group 0, which is slightly faster. -// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 -template -CUTE_HOST_DEVICE -void cp_async_wait() { -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) - asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void copy(TiledCopy thr_copy, Tensor const &S, - Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, int max_MN=0) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - clear(D(_, m, k)); - } - } - } else if (Clear_OOB_MN) { - clear(D(_, m, _)); - } - } - // TD [2023-04-13]: Strange that the code below can cause race condition. - // I think it's because the copies are under an if statement. - // if (Is_even_K) { - // #pragma unroll - // for (int m = 0; m < size<1>(S); ++m) { - // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, _), D(_, m, _)); - // } else if (Clear_OOB_MN) { - // clear(D(_, m, _)); - // } - // } - // } else { // It's slightly faster in this case if iterate over K first - // #pragma unroll - // for (int k = 0; k < size<2>(S); ++k) { - // if (predicate_K(k)) { - // #pragma unroll - // for (int m = 0; m < size<1>(S); ++m) { - // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - // copy(thr_copy, S(_, m, k), D(_, m, k)); - // } else if (Clear_OOB_MN) { - // clear(D(_, m, k)); - // } - // } - // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN - // if (Clear_OOB_MN || Is_even_MN) { - // clear(D(_, _, k)); - // } else { - // #pragma unroll - // for (int m = 0; m < size<1>(S); ++m) { - // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { - // clear(D(_, m, k)); - // } - // } - // } - // } - // } - // } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash diff --git a/colossalai/kernel/cuda_native/setup.py b/colossalai/kernel/cuda_native/setup.py index 1be6690b46ac..4d61b344a2cd 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/colossalai/kernel/cuda_native/setup.py @@ -159,54 +159,6 @@ def append_nvcc_threads(nvcc_extra_args): ], ), - CUDAExtension( - name="col_flash_attn_2_lib", - sources=[ - "csrc/attention_infer_kernels/flash_attn/flash_api.cpp", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/attention_infer_kernels/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'csrc'/'attention_infer_kernels'/'flash_attn' , - Path(this_dir) / 'csrc'/ 'attention_infer_kernels'/'flash_attn' / 'src', - Path(this_dir) / 'csrc'/'cutlass' / 'include', - ], - ), - ], cmdclass={ 'build_ext': BuildExtension From c5249257f0c9bf57fc38eb14a9755243b3c3d245 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 17:49:43 +0800 Subject: [PATCH 14/31] remove cutlass --- colossalai/kernel/cuda_native/csrc/cutlass | 1 - 1 file changed, 1 deletion(-) delete mode 160000 colossalai/kernel/cuda_native/csrc/cutlass diff --git a/colossalai/kernel/cuda_native/csrc/cutlass b/colossalai/kernel/cuda_native/csrc/cutlass deleted file mode 160000 index c4f6b8c6bc94..000000000000 --- a/colossalai/kernel/cuda_native/csrc/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c4f6b8c6bc94ff69048492fb34df0dfaf1983933 From 8f515491344f11631b678d06c7568ecb49b40ae7 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 17:54:19 +0800 Subject: [PATCH 15/31] delete cutlass --- .gitmodules | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 9c4a5ae34744..2f1c34298a50 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,6 +5,3 @@ [submodule "examples/tutorial/fastfold/FastFold"] path = examples/tutorial/fastfold/FastFold url = https://github.com/hpcaitech/FastFold -[submodule "colossalai/kernel/cuda_native/csrc/cutlass"] - path = colossalai/kernel/cuda_native/csrc/cutlass - url = https://github.com/NVIDIA/cutlass From 9efe2e9ed47361b977f77135d5fa5b467b66f5f5 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 20:32:35 +0800 Subject: [PATCH 16/31] cleaned codes --- .../llama_infer_cuda.py | 45 +++++-------------- setup.py | 6 +++ 2 files changed, 17 insertions(+), 34 deletions(-) rename colossalai/kernel/cuda_native/setup.py => op_builder/llama_infer_cuda.py (76%) diff --git a/colossalai/kernel/cuda_native/setup.py b/op_builder/llama_infer_cuda.py similarity index 76% rename from colossalai/kernel/cuda_native/setup.py rename to op_builder/llama_infer_cuda.py index 4d61b344a2cd..8d635b4da216 100644 --- a/colossalai/kernel/cuda_native/setup.py +++ b/op_builder/llama_infer_cuda.py @@ -17,8 +17,7 @@ # ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - +this_dir = os.getcwd() def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) @@ -29,24 +28,6 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - def raise_if_cuda_home_none(global_option: str) -> None: if CUDA_HOME is not None: return @@ -117,14 +98,12 @@ def append_nvcc_threads(nvcc_extra_args): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") -setup( - name='colossal-cuda-infer-kernels', - ext_modules=[ +llama_cuda_submodules = [ CUDAExtension( name='col_fused_softmax_lib', sources=[ - 'csrc/attention_infer_kernels/softmax/fused_softmax.cpp', - 'csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu' + this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp', + this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu' ], extra_compile_args={ 'cxx': ['-O3',], @@ -135,8 +114,8 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_pos_encoding_ops", sources=[ - "csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp", - "csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu" + this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp", + this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu" ], extra_compile_args={ 'cxx': ['-O3',], @@ -147,22 +126,20 @@ def append_nvcc_threads(nvcc_extra_args): CUDAExtension( name="col_rms_norm_ops", sources=[ - "csrc/attention_infer_kernels/rmsnorm/layernorm.cpp", - "csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu" + this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp", + this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu" ], extra_compile_args={ 'cxx': ['-O3',], 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) }, include_dirs=[ - Path(this_dir)/'csrc'/'attention_infer_kernels'/'rmsnorm', + this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm', ], ), - ], - cmdclass={ - 'build_ext': BuildExtension -}) + ] + diff --git a/setup.py b/setup.py index 5d8f831218d9..6002dd487361 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 +LLAMA_INFER_CUDA = int(os.environ.get("LLAMA_INFER_CUDA", "0")) == 1 # a variable to store the op builder ext_modules = [] @@ -138,6 +139,11 @@ def get_version() -> str: op_name_list = ', '.join(op_names) print(f"[extension] loaded builders for {op_name_list}") +if LLAMA_INFER_CUDA: + from op_builder.llama_infer_cuda import llama_cuda_submodules + for sub_module in llama_cuda_submodules: + ext_modules.append(sub_module) + # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name # and it will mess up with the dependency graph insights From caefcba0d0bab61a59916af3efa6c4d489fda311 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 21:52:59 +0800 Subject: [PATCH 17/31] add --- op_builder/llama_infer_cuda.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/op_builder/llama_infer_cuda.py b/op_builder/llama_infer_cuda.py index 8d635b4da216..15fb9077796a 100644 --- a/op_builder/llama_infer_cuda.py +++ b/op_builder/llama_infer_cuda.py @@ -1,4 +1,3 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py import sys import warnings import os @@ -70,21 +69,6 @@ def append_nvcc_threads(nvcc_extra_args): os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("flash_attn") # Check, if CUDA11 is installed for compute capability 8.0 cc_flag = [] _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) From ad3fa468ca7ae47fd5a0acedfc42be67102ff1b0 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 23:05:45 +0800 Subject: [PATCH 18/31] add norm --- op_builder/llama_infer_cuda.py | 26 +------------------ tests/test_kernels/cuda/test_rmsnorm.py | 7 ++--- .../cuda/test_rotary_embedding.py | 7 ++--- tests/test_kernels/cuda/test_softmax.py | 23 ++++++++-------- 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/op_builder/llama_infer_cuda.py b/op_builder/llama_infer_cuda.py index 15fb9077796a..1fd4f432eee9 100644 --- a/op_builder/llama_infer_cuda.py +++ b/op_builder/llama_infer_cuda.py @@ -26,17 +26,6 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - def append_nvcc_threads(nvcc_extra_args): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version >= Version("11.2"): @@ -45,18 +34,6 @@ def append_nvcc_threads(nvcc_extra_args): if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version >= Version("11.8"): @@ -74,8 +51,7 @@ def append_nvcc_threads(nvcc_extra_args): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.0"): raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") -# cc_flag.append("-gencode") -# cc_flag.append("arch=compute_75,code=sm_75") + cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") if bare_metal_version >= Version("11.8"): diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py index 8f748dedefee..3e7da8412a01 100644 --- a/tests/test_kernels/cuda/test_rmsnorm.py +++ b/tests/test_kernels/cuda/test_rmsnorm.py @@ -1,5 +1,7 @@ import os +import pytest import numpy as np +from packaging import version import torch from torch import nn @@ -41,7 +43,7 @@ def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): ) return out - +@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install llama supported cuda kernels to run this test") def test_rmsnorm(): data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") hg_rms = LlamaRMSNorm(64) @@ -53,5 +55,4 @@ def test_rmsnorm(): assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" if __name__ == "__main__": - if HAS_INFER_CUDA: - test_rmsnorm() \ No newline at end of file + test_rmsnorm() \ No newline at end of file diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py index 178857eae4ad..6445fba7c7d2 100644 --- a/tests/test_kernels/cuda/test_rotary_embedding.py +++ b/tests/test_kernels/cuda/test_rotary_embedding.py @@ -1,4 +1,6 @@ +import pytest from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -135,7 +137,7 @@ def run_rotary_embedding_neox( assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) - +@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install llama supported cuda kernels to run this test") def test(): run_rotary_embedding_neox( num_tokens=1024, @@ -147,5 +149,4 @@ def test(): ) if __name__ == "__main__": - if HAS_INFER_CUDA: - test() \ No newline at end of file + test() \ No newline at end of file diff --git a/tests/test_kernels/cuda/test_softmax.py b/tests/test_kernels/cuda/test_softmax.py index 879e8981c85f..8b3b99dcc981 100644 --- a/tests/test_kernels/cuda/test_softmax.py +++ b/tests/test_kernels/cuda/test_softmax.py @@ -1,4 +1,5 @@ import os +import pytest import numpy as np import torch @@ -10,18 +11,18 @@ HAS_INFER_CUDA = False print("please install your cuda ") -if HAS_INFER_CUDA: - def test(): - size = (17, 3, 1024, 256) - data = torch.randn(size = size, device="cuda", dtype=torch.float16) - mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) +@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install llama supported cuda kernels to run this test") +def test(): + size = (17, 3, 1024, 256) + data = torch.randn(size = size, device="cuda", dtype=torch.float16) + mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) - out_cuda = scaled_masked_softmax_forward(data, mask, 1) + out_cuda = scaled_masked_softmax_forward(data, mask, 1) - out_torch = F.softmax(data, dim = -1) + out_torch = F.softmax(data, dim = -1) - check = torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-3, atol=1e-3) - assert check is True, "the output from cuda softmax is not matched with output from torch" + check = torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "the output from cuda softmax is not matched with output from torch" - if __name__ == "__main__": - test() \ No newline at end of file +if __name__ == "__main__": + test() \ No newline at end of file From 7ba2d61b9d882e4bdc3365f763512244c73b9a6d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 15 Aug 2023 23:09:10 +0800 Subject: [PATCH 19/31] delete useless files --- colossalai/kernel/cuda_native/bert_padding.py | 132 ------ .../cuda_native/flash_attn_interface.py | 385 ------------------ colossalai/kernel/cuda_native/linear.py | 48 --- 3 files changed, 565 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/bert_padding.py delete mode 100644 colossalai/kernel/cuda_native/flash_attn_interface.py delete mode 100644 colossalai/kernel/cuda_native/linear.py diff --git a/colossalai/kernel/cuda_native/bert_padding.py b/colossalai/kernel/cuda_native/bert_padding.py deleted file mode 100644 index 6826949a2604..000000000000 --- a/colossalai/kernel/cuda_native/bert_padding.py +++ /dev/null @@ -1,132 +0,0 @@ -# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py - -import torch -import torch.nn.functional as F - -from einops import rearrange, repeat - - -class IndexFirstAxis(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, - repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, 'b ... -> b (...)') - grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, dtype=grad_output.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis = IndexFirstAxis.apply - - -class IndexPutFirstAxis(torch.autograd.Function): - - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, - dtype=values.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - - -index_put_first_axis = IndexPutFirstAxis.apply - - -class IndexFirstAxisResidual(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - output = input[indices] - # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last - # memory format to channel_first. In other words, input might not be contiguous. - # If we don't detach, Pytorch complains about output being a view and is being modified inplace - return output, input.detach() - - @staticmethod - def backward(ctx, grad_output, grad_residual): - indices, = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - assert grad_residual.shape[1:] == other_shape - grad_input = grad_residual - # grad_input[indices] += grad_output - indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) - indices = indices.expand_as(grad_output) - grad_input.scatter_add_(0, indices, grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - - -index_first_axis_residual = IndexFirstAxisResidual.apply - - -def unpad_input(hidden_states, attention_mask): - """ - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, - cu_seqlens, max_seqlen_in_batch) - - -def pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz) - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, '(b s) ... -> b s ...', b=batch) \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/flash_attn_interface.py b/colossalai/kernel/cuda_native/flash_attn_interface.py deleted file mode 100644 index bf57ea7e9871..000000000000 --- a/colossalai/kernel/cuda_native/flash_attn_interface.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch -import torch.nn as nn - -from einops import rearrange -import col_flash_attn_2_lib as flash_attn_cuda - -def _get_block_size(device, head_dim, is_dropout, is_causal): - # This should match the block sizes in the CUDA kernel - assert head_dim <= 256 - major, minor = torch.cuda.get_device_capability(device) - is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) - is_sm80 = major == 8 and minor == 0 - is_sm90 = major == 9 and minor == 0 - if head_dim <= 32: - return 128, 128 - if head_dim <= 64: - return (128, 128) if not is_dropout else (128, 64) - elif head_dim <= 96: - return (64, 64) if (is_sm8x and is_causal) else (128, 64) - elif head_dim <= 128: - if is_sm8x: - return (64, 64) if (not is_dropout and is_causal) else (128, 32) - else: - return 128, (64 if not is_dropout else 32) - elif head_dim <= 160: - if is_sm8x: - return (128, 64) if not is_causal else (64, 64) - else: - return 128, 32 - elif head_dim <= 192: - return (128, 64) if not is_dropout else (64, 64) - elif head_dim <= 224: - return (128, 64) if (is_sm80 or is_sm90) else (64, 64) - elif head_dim <= 256: - return (128, 64) if is_sm80 else (64, 64) - - -def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.flash_fwd( - q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None - ) - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - -def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_softmax): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_flash_fwd( - q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, False, causal, return_softmax, None - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state - - -class FlashAttnQKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, - causal=causal, return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - -class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - - -class FlashAttnKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal, - return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - - -class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, - cu_seqlens_q, cu_seqlens_k, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - - -class FlashAttnFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( - q, k, v, dropout_p, softmax_scale, causal=causal, - return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - -class FlashAttnVarlenFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0 - ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, - cu_seqlens_q, cu_seqlens_k, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - return out if not return_softmax else (out, softmax_lse, S_dmask) - - - -def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, - return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_kvpacked_func and flash_attn_func. - - Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) - - -def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, - return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - kv: (batch_size, seqlen, 2, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) - - -def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, - return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) - - -def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, - causal=False, return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - If Q, K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of Q, K, V. - For multi-query and grouped-query attention (MQA/GQA), please see - flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. - - Arguments: - qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into qkv. - max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenQKVPackedFunc.apply( - qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs - ) - - -def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p=0.0, softmax_scale=None, causal=False, - return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - If K, V are already stacked into 1 tensor, this function will be faster than - calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation - of the gradients of K, V. - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenKVPackedFunc.apply( - q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_attn_probs - ) - - -def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p=0.0, softmax_scale=None, causal=False, - return_attn_probs=False): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_attn_probs - ) \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/linear.py b/colossalai/kernel/cuda_native/linear.py deleted file mode 100644 index 718cbcbad4d8..000000000000 --- a/colossalai/kernel/cuda_native/linear.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -try: - from col_linear_lib import dense_layer_fp32_forward, dense_layer_fp16_forward, batch_dense_layer_fp16_forward - HAS_FLASH_CUDA = True -except: - HAS_FLASH_CUDA = False - print("in order to use flash-attention, make sure you install cuda kernels in op directory") - - -if HAS_FLASH_CUDA: - def linear(data, weight): - data_shape = None - if len(data.shape) > 2: - data_shape = data.shape - data = data.view(-1, data.shape[-1]) - - assert data.dtype == torch.float16, "only fp16 precision supports" - assert len(data.shape) == 2, "the shape must be 2-D" - assert len(weight.shape) == 2, "the shape must be 2-D" - - M, K = data.shape - _, N = weight.shape - - assert K == weight.shape[0], "the shape is not matchted" - - out = torch.empty((M, N), device=data.get_device(), dtype=torch.float16) - dense_layer_fp16_forward(data, weight, out, 99) - if data_shape is not None: - out = out.view(*data_shape[:-1], N) - return out - - - def batch_linear(data, weight, alibi = None, alpha = 1, beta = 0): - """ - it is equivalent to alibi.bmm(data, weight) - only supports float16 - """ - batch_count, M, K = data.shape - _, N = weight.shape - assert data.shape[-1] == weight.shape[0], "the k-dimensions must be matched" - if alibi is None: - out = torch.empty((batch_count, M, N), dtype=torch.float16, device=data.get_device()) - else: - out = alibi - - batch_dense_layer_fp16_forward(data, weight, out, alpha, beta, 99) - return out - \ No newline at end of file From 03e4149e3d10677f448da96b09bddcde0700adad Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 10:59:14 +0800 Subject: [PATCH 20/31] chnage setup --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6002dd487361..495290f7f2e7 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 -LLAMA_INFER_CUDA = int(os.environ.get("LLAMA_INFER_CUDA", "0")) == 1 +INFER_CUDA_KERNEL = int(os.environ.get("INFER_CUDA_KERNEL", "0")) == 1 # a variable to store the op builder ext_modules = [] @@ -139,7 +139,7 @@ def get_version() -> str: op_name_list = ', '.join(op_names) print(f"[extension] loaded builders for {op_name_list}") -if LLAMA_INFER_CUDA: +if INFER_CUDA_KERNEL: from op_builder.llama_infer_cuda import llama_cuda_submodules for sub_module in llama_cuda_submodules: ext_modules.append(sub_module) From 746c1c991abdc0726309c60513d3bcb895285ca4 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 12:44:00 +0800 Subject: [PATCH 21/31] change intp build class --- .../rmsnorm/layernorm_kernels.cu | 3 - .../rotary_embedding/pos_encoding.cpp | 3 - .../rotary_embedding/pos_encoding_kernels.cu | 3 - .../softmax/fused_softmax.cpp | 46 --- .../softmax/scaled_masked_softmax.h | 338 ------------------ .../softmax/scaled_masked_softmax_cuda.cu | 78 ---- .../softmax/type_shim.h | 20 -- op_builder/__init__.py | 8 +- op_builder/llama_infer_cuda.py | 108 ------ op_builder/rmsnorm.py | 39 ++ op_builder/rotary_embedding.py | 67 ++++ setup.py | 5 - tests/test_kernels/cuda/test_rmsnorm.py | 6 +- .../cuda/test_rotary_embedding.py | 4 +- tests/test_kernels/cuda/test_softmax.py | 28 -- 15 files changed, 119 insertions(+), 637 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu delete mode 100644 colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h delete mode 100644 op_builder/llama_infer_cuda.py create mode 100644 op_builder/rmsnorm.py create mode 100644 op_builder/rotary_embedding.py delete mode 100644 tests/test_kernels/cuda/test_softmax.py diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu index 7f7889df3ec4..cab3501bac7c 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu @@ -1,6 +1,3 @@ -/*This code from Vllm : https://github.com/vllm-project/vllm - * with minor changes. */ - #include #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp index 16749fd52155..565d134cdedf 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -1,6 +1,3 @@ -/*This code from Vllm : https://github.com/vllm-project/vllm - * with minor changes. */ - #include void rotary_embedding_neox( diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu index 8324f1e66556..1f0f8968619b 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -1,6 +1,3 @@ -/*This code from Vllm : https://github.com/vllm-project/vllm - * with minor changes. */ - #include #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp deleted file mode 100644 index ffc68e6c731c..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("scaled_masked_softmax_forward", - &fwd, - "self-multihead attention scaled masked softmax(forward)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h deleted file mode 100644 index d923ade203dc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax.h +++ /dev/null @@ -1,338 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; - } - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - - -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu deleted file mode 100644 index de3547671c5f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,78 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - - - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches - ); - ); - return softmax_results; -} diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h deleted file mode 100644 index 815ec7ec8896..000000000000 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/type_shim.h +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ -switch(TYPE) \ -{ \ -case at::ScalarType::Half: \ - { \ -using scalar_t = at::Half; \ -__VA_ARGS__; \ -break; \ - } \ -case at::ScalarType::BFloat16: \ - { \ -using scalar_t = at::BFloat16; \ -__VA_ARGS__; \ -break; \ - } \ -default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -} diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 5ae7223b8c69..85d26a258d23 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -5,6 +5,8 @@ from .multi_head_attn import MultiHeadAttnBuilder from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder +from .rmsnorm import RMSNORMBuilder +from .rotary_embedding import ROTARYEMBEDDINGBuilder ALL_OPS = { 'cpu_adam': CPUAdamBuilder, @@ -14,10 +16,14 @@ 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, 'layernorm': LayerNormBuilder, + 'rmsnorm': RMSNORMBuilder, + 'rotary_embedding': ROTARYEMBEDDINGBuilder, } __all__ = [ 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', - 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' + 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder', + 'RMSNORMBuilder', + 'ROTARYEMBEDDINGBuilder', ] diff --git a/op_builder/llama_infer_cuda.py b/op_builder/llama_infer_cuda.py deleted file mode 100644 index 1fd4f432eee9..000000000000 --- a/op_builder/llama_infer_cuda.py +++ /dev/null @@ -1,108 +0,0 @@ -import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME -from torch.utils import cpp_extension - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.getcwd() - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") - -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -llama_cuda_submodules = [ - CUDAExtension( - name='col_fused_softmax_lib', - sources=[ - this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/fused_softmax.cpp', - this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/softmax/scaled_masked_softmax_cuda.cu' - ], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ), - - CUDAExtension( - name="col_pos_encoding_ops", - sources=[ - this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp", - this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu" - ], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - }, - ), - - CUDAExtension( - name="col_rms_norm_ops", - sources=[ - this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp", - this_dir + "/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu" - ], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - }, - include_dirs=[ - this_dir + '/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm', - ], - ), - - ] - - - - - - - diff --git a/op_builder/rmsnorm.py b/op_builder/rmsnorm.py new file mode 100644 index 000000000000..27d1d879d6d6 --- /dev/null +++ b/op_builder/rmsnorm.py @@ -0,0 +1,39 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + +class RMSNORMBuilder(Builder): + + NAME = "rmsnorm" + PREBUILT_IMPORT_PATH = "colossalai._C.rmsnorm" + + def __init__(self): + super().__init__(name=RMSNORMBuilder.NAME, + prebuilt_import_path=RMSNORMBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("attention_infer_kernels/rmsnorm"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'attention_infer_kernels/rmsnorm/layernorm_kernels.cu', + 'attention_infer_kernels/rmsnorm/layernorm.cpp' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/rotary_embedding.py b/op_builder/rotary_embedding.py new file mode 100644 index 000000000000..5beb4f56d617 --- /dev/null +++ b/op_builder/rotary_embedding.py @@ -0,0 +1,67 @@ +import os +from packaging.version import parse, Version +from setuptools import setup, find_packages +import subprocess + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME + +def add_cc_flags(): + def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + print(bare_metal_version) + if bare_metal_version < Version("11.0"): + raise RuntimeError("FlashAttention is only supported on CUDA 11 and above") + + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + return cc_flag + +class ROTARYEMBEDDINGBuilder(Builder): + + NAME = "rotary_embedding" + PREBUILT_IMPORT_PATH = "colossalai._C.rotary_embedding" + + def __init__(self): + super().__init__(name=ROTARYEMBEDDINGBuilder.NAME, + prebuilt_import_path=ROTARYEMBEDDINGBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("attention_infer_kernels/rotary_embedding"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu', + 'attention_infer_kernels/rotary_embedding/pos_encoding.cpp' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) + add_cc_flags() diff --git a/setup.py b/setup.py index 495290f7f2e7..4e530909fea0 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,6 @@ THIS_DIR = os.path.dirname(os.path.abspath(__file__)) BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 -INFER_CUDA_KERNEL = int(os.environ.get("INFER_CUDA_KERNEL", "0")) == 1 # a variable to store the op builder ext_modules = [] @@ -139,10 +138,6 @@ def get_version() -> str: op_name_list = ', '.join(op_names) print(f"[extension] loaded builders for {op_name_list}") -if INFER_CUDA_KERNEL: - from op_builder.llama_infer_cuda import llama_cuda_submodules - for sub_module in llama_cuda_submodules: - ext_modules.append(sub_module) # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py index 3e7da8412a01..8d3ee4eb8de5 100644 --- a/tests/test_kernels/cuda/test_rmsnorm.py +++ b/tests/test_kernels/cuda/test_rmsnorm.py @@ -7,7 +7,9 @@ from torch import nn from torch.nn import functional as F try: - from col_rms_norm_ops import rms_norm + from colossalai.kernel.op_builder import RMSNORMBuilder + rmsnorm = RMSNORMBuilder().load() + rms_norm = rmsnorm.rms_norm HAS_INFER_CUDA = True except: HAS_INFER_CUDA = False @@ -30,8 +32,6 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) - - def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): x = hidden_states out = torch.empty_like(x) diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py index 6445fba7c7d2..be43a1061437 100644 --- a/tests/test_kernels/cuda/test_rotary_embedding.py +++ b/tests/test_kernels/cuda/test_rotary_embedding.py @@ -7,7 +7,9 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half try: - from col_pos_encoding_ops import rotary_embedding_neox + from colossalai.kernel.op_builder import ROTARYEMBEDDINGBuilder + rotary_embedding = ROTARYEMBEDDINGBuilder().load() + rotary_embedding_neox = rotary_embedding.rotary_embedding_neox HAS_INFER_CUDA = True except: HAS_INFER_CUDA = False diff --git a/tests/test_kernels/cuda/test_softmax.py b/tests/test_kernels/cuda/test_softmax.py deleted file mode 100644 index 8b3b99dcc981..000000000000 --- a/tests/test_kernels/cuda/test_softmax.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import pytest -import numpy as np - -import torch -from torch.nn import functional as F -try: - from col_fused_softmax_lib import scaled_masked_softmax_forward - HAS_INFER_CUDA = True -except: - HAS_INFER_CUDA = False - print("please install your cuda ") - -@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install llama supported cuda kernels to run this test") -def test(): - size = (17, 3, 1024, 256) - data = torch.randn(size = size, device="cuda", dtype=torch.float16) - mask = torch.zeros(size = (17, 1, 1024, 256), device="cuda", dtype=torch.uint8) - - out_cuda = scaled_masked_softmax_forward(data, mask, 1) - - out_torch = F.softmax(data, dim = -1) - - check = torch.allclose(out_cuda.cpu(), out_torch.cpu(), rtol=1e-3, atol=1e-3) - assert check is True, "the output from cuda softmax is not matched with output from torch" - -if __name__ == "__main__": - test() \ No newline at end of file From b1a2c19229eb17742d62366f6ba5c8d3712a2137 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 14:14:28 +0800 Subject: [PATCH 22/31] added info --- tests/test_kernels/cuda/test_rmsnorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py index 8d3ee4eb8de5..83aaa27d508f 100644 --- a/tests/test_kernels/cuda/test_rmsnorm.py +++ b/tests/test_kernels/cuda/test_rmsnorm.py @@ -51,7 +51,7 @@ def test_rmsnorm(): out_torch = hg_rms(data) out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) - check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-3) + check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" if __name__ == "__main__": From d09a81a34dbcdd9879ab4d978f0e87d30e5cdda4 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 15:06:44 +0800 Subject: [PATCH 23/31] added lisense --- .../csrc/attention_infer_kernels/rmsnorm/layernorm.cpp | 3 +++ .../csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu | 3 +++ .../attention_infer_kernels/rotary_embedding/pos_encoding.cpp | 3 +++ .../rotary_embedding/pos_encoding_kernels.cu | 3 +++ 4 files changed, 12 insertions(+) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp index 749ca5f92154..bd5e269243ee 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp @@ -1,3 +1,6 @@ +/* ---------------- LICENSE FOR Colossal-AI ---------------- +Copyright 2021- HPC-AI Technology Inc. +*/ #include void rms_norm( diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu index cab3501bac7c..e386c292f817 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu @@ -1,3 +1,6 @@ +/* ---------------- LICENSE FOR Colossal-AI ---------------- +Copyright 2021- HPC-AI Technology Inc. +*/ #include #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp index 565d134cdedf..133492b2410f 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -1,3 +1,6 @@ +/* ---------------- LICENSE FOR Colossal-AI ---------------- +Copyright 2021- HPC-AI Technology Inc. +*/ #include void rotary_embedding_neox( diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu index 1f0f8968619b..fce01794f3a3 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -1,3 +1,6 @@ +/* ---------------- LICENSE FOR Colossal-AI ---------------- +Copyright 2021- HPC-AI Technology Inc. +*/ #include #include From 0b0b28cb24108e4ae7b893df952a1a92a3b3abbb Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 15:12:52 +0800 Subject: [PATCH 24/31] added lisense --- .../attention_infer_kernels/rmsnorm/layernorm.cpp | 15 ++++++++++++++- .../rmsnorm/layernorm_kernels.cu | 15 ++++++++++++++- .../rotary_embedding/pos_encoding.cpp | 15 ++++++++++++++- .../rotary_embedding/pos_encoding_kernels.cu | 15 ++++++++++++++- tests/test_kernels/cuda/test_rmsnorm.py | 2 ++ tests/test_kernels/cuda/test_rotary_embedding.py | 2 ++ 6 files changed, 60 insertions(+), 4 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp index bd5e269243ee..65cd73d366f4 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp @@ -1,5 +1,18 @@ /* ---------------- LICENSE FOR Colossal-AI ---------------- -Copyright 2021- HPC-AI Technology Inc. +# coding=utf-8 +# Copyright 2021- HPC-AI Technology Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. */ #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu index e386c292f817..2daaa18e0110 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu @@ -1,5 +1,18 @@ /* ---------------- LICENSE FOR Colossal-AI ---------------- -Copyright 2021- HPC-AI Technology Inc. +# coding=utf-8 +# Copyright 2021- HPC-AI Technology Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. */ #include #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp index 133492b2410f..f106773cef34 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -1,5 +1,18 @@ /* ---------------- LICENSE FOR Colossal-AI ---------------- -Copyright 2021- HPC-AI Technology Inc. +# coding=utf-8 +# Copyright 2021- HPC-AI Technology Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. */ #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu index fce01794f3a3..6db72ad0c30c 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -1,5 +1,18 @@ /* ---------------- LICENSE FOR Colossal-AI ---------------- -Copyright 2021- HPC-AI Technology Inc. +# coding=utf-8 +# Copyright 2021- HPC-AI Technology Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. */ #include #include diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py index 83aaa27d508f..a09f96b56816 100644 --- a/tests/test_kernels/cuda/test_rmsnorm.py +++ b/tests/test_kernels/cuda/test_rmsnorm.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- import os import pytest import numpy as np diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py index be43a1061437..849eefc59179 100644 --- a/tests/test_kernels/cuda/test_rotary_embedding.py +++ b/tests/test_kernels/cuda/test_rotary_embedding.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- import pytest from typing import Tuple From 8dd322f95aecec5193e5dcda6e43b30128b0e55e Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 15:35:16 +0800 Subject: [PATCH 25/31] added lisense --- LICENSE | 18 ++++++++++++++++++ .../rmsnorm/layernorm.cpp | 18 +++--------------- .../rmsnorm/layernorm_kernels.cu | 18 +++--------------- .../rotary_embedding/pos_encoding.cpp | 18 +++--------------- .../rotary_embedding/pos_encoding_kernels.cu | 18 +++--------------- 5 files changed, 30 insertions(+), 60 deletions(-) diff --git a/LICENSE b/LICENSE index c7a5bb16880e..3f2167477589 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,21 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR VLLM TEAM ---------------- + + from VLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/vllm-project/vllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp index 65cd73d366f4..30d236dee895 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp @@ -1,18 +1,6 @@ -/* ---------------- LICENSE FOR Colossal-AI ---------------- -# coding=utf-8 -# Copyright 2021- HPC-AI Technology Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +/* Copyright 2021 The Colossal-AI Team + Copyright (c) 2023, The vLLM team. + This file is adapted from vllm TEAM: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm.cpp */ #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu index 2daaa18e0110..1f6bff6a13d1 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu @@ -1,18 +1,6 @@ -/* ---------------- LICENSE FOR Colossal-AI ---------------- -# coding=utf-8 -# Copyright 2021- HPC-AI Technology Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +/* Copyright 2021 The Colossal-AI Team + Copyright (c) 2023, The vLLM team. + This file is adapted from vllm TEAM: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_kernels.cu */ #include #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp index f106773cef34..21214d764d36 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp @@ -1,18 +1,6 @@ -/* ---------------- LICENSE FOR Colossal-AI ---------------- -# coding=utf-8 -# Copyright 2021- HPC-AI Technology Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +/* Copyright 2021 The Colossal-AI Team + Copyright (c) 2023, The vLLM team. + This file is adapted from vllm TEAM: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding.cpp */ #include diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu index 6db72ad0c30c..3b0767e23938 100644 --- a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu @@ -1,18 +1,6 @@ -/* ---------------- LICENSE FOR Colossal-AI ---------------- -# coding=utf-8 -# Copyright 2021- HPC-AI Technology Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +/* Copyright 2021 The Colossal-AI Team + Copyright (c) 2023, The vLLM team. + This file is adapted from vllm TEAM: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu */ #include #include From 1d2294ec1e6aa86fb9bfd37a8618c35033c43ff7 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 15:42:02 +0800 Subject: [PATCH 26/31] change name --- tests/test_kernels/cuda/test_rotary_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py index 849eefc59179..e48d252d0add 100644 --- a/tests/test_kernels/cuda/test_rotary_embedding.py +++ b/tests/test_kernels/cuda/test_rotary_embedding.py @@ -142,7 +142,7 @@ def run_rotary_embedding_neox( assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) @pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install llama supported cuda kernels to run this test") -def test(): +def test_rotary_embedding(): run_rotary_embedding_neox( num_tokens=1024, num_heads=8, @@ -153,4 +153,4 @@ def test(): ) if __name__ == "__main__": - test() \ No newline at end of file + test_rotary_embedding() \ No newline at end of file From 1b53af9df75f5e885ec2de44facfae5800f8221c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 15:51:54 +0800 Subject: [PATCH 27/31] change flash version --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 65eecce2c34f..47c4dbc6bf9c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,5 +10,5 @@ contexttimer ninja torch>=1.11 safetensors -flash_attn>=2.0 +flash_attn>=2.0.5 einops From 9313374c6174bfd019b6ef665e6e1d9a7460de2c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 16:09:04 +0800 Subject: [PATCH 28/31] change req --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 47c4dbc6bf9c..cd7c988e9029 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,5 +10,5 @@ contexttimer ninja torch>=1.11 safetensors -flash_attn>=2.0.5 +flash_attn==2.0.5 einops From 90859c68c03481caa79933169101f420021372ab Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 16 Aug 2023 19:26:35 +0800 Subject: [PATCH 29/31] added it --- op_builder/rotary_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/op_builder/rotary_embedding.py b/op_builder/rotary_embedding.py index 5beb4f56d617..3633becbc054 100644 --- a/op_builder/rotary_embedding.py +++ b/op_builder/rotary_embedding.py @@ -8,6 +8,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME + def add_cc_flags(): def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) From b921fa2e4bef21a93ea57ceaa0234031d34d4809 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 17 Aug 2023 14:40:11 +0800 Subject: [PATCH 30/31] modify req --- requirements/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 68d278c889e9..9aa5f2822e40 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,5 +10,4 @@ contexttimer ninja torch>=1.12 safetensors -flash_attn==2.0.5 einops From e4341de1ddb90697f23a6a191eb9ee6a9f958680 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 17 Aug 2023 21:34:47 +0800 Subject: [PATCH 31/31] added _vllm_rmsnorm_forward and fix llama forward --- colossalai/shardformer/modeling/llama.py | 44 ++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f1d2998bbee4..ae3d8e207881 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,7 +7,7 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -391,9 +391,17 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True + except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + def forward( self: LlamaAttention, hidden_states: torch.Tensor, @@ -415,7 +423,12 @@ def forward( kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if HAS_VLLM_KERNERL: + cos_sin_cache = torch.cat((cos, sin), dim=-1) + rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -450,3 +463,28 @@ def forward( return attn_output, None, past_key_value return forward + + +def get_llama_vllm_rmsnorm_forward() + try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True + except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward