// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"

#define THREADS_PER_BLOCK 128

| 22 | +template <typename T> |
| 23 | +struct Converter; |
| 24 | + |
| 25 | +template <> |
| 26 | +struct Converter<__half> { |
| 27 | + // __half -> float |
| 28 | + __device__ static float to_float(__half val) { return __half2float(val); } |
| 29 | + // float -> __half |
| 30 | + __device__ static __half from_float(float val) { |
| 31 | + return __float2half_rn(val); |
| 32 | + } |
| 33 | + // int -> __half |
| 34 | + __device__ static __half from_int(float val) { return __int2half_rn(val); } |
| 35 | +}; |
| 36 | + |
| 37 | +template <> |
| 38 | +struct Converter<__nv_bfloat16> { |
| 39 | + // __nv_bfloat16 -> float |
| 40 | + __device__ static float to_float(__nv_bfloat16 val) { |
| 41 | + return __bfloat162float(val); |
| 42 | + } |
| 43 | + // float -> __nv_bfloat16 |
| 44 | + __device__ static __nv_bfloat16 from_float(float val) { |
| 45 | + return __float2bfloat16_rn(val); |
| 46 | + } |
| 47 | + // int -> __nv_bfloat16 |
| 48 | + __device__ static __nv_bfloat16 from_int(int val) { |
| 49 | + return __int2bfloat16_rn(val); |
| 50 | + } |
| 51 | +}; |
| 52 | + |
| 53 | +template <typename T> |
| 54 | +__device__ void RotateQKVec4(const T* qk_ptr, |
| 55 | + const T* rot_cos_ptr, |
| 56 | + const T* rot_sin_ptr, |
| 57 | + const int head_num, |
| 58 | + const int base_idx, |
| 59 | + const int rot_base_idx, |
| 60 | + T* out) { |
| 61 | + using VecT = AlignedVector<T, 4>; |
| 62 | + |
| 63 | + VecT qk_vec; |
| 64 | + Load(qk_ptr + base_idx, &qk_vec); |
| 65 | + VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]}; |
| 66 | + VecT cos_vec, sin_vec; |
| 67 | + Load(rot_cos_ptr + rot_base_idx, &cos_vec); |
| 68 | + Load(rot_sin_ptr + rot_base_idx, &sin_vec); |
| 69 | +#pragma unroll |
| 70 | + for (int i = 0; i < 4; ++i) { |
| 71 | + *(out + base_idx + i) = |
| 72 | + qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +template <typename T> |
| 77 | +__device__ void RotateQKVec4(const T* qk_ptr, |
| 78 | + const float* rot_cos_ptr, |
| 79 | + const float* rot_sin_ptr, |
| 80 | + const int head_num, |
| 81 | + const int base_idx, |
| 82 | + const int rot_base_idx, |
| 83 | + T* out) { |
| 84 | + using VecT = AlignedVector<T, 4>; |
| 85 | + using VecF = AlignedVector<float, 4>; |
| 86 | + auto to_float = [] __device__(T val) -> float { |
| 87 | + return Converter<T>::to_float(val); |
| 88 | + }; |
| 89 | + auto from_float = [] __device__(float val) -> T { |
| 90 | + return Converter<T>::from_float(val); |
| 91 | + }; |
| 92 | + |
| 93 | + VecT qk_vec; |
| 94 | + Load(qk_ptr + base_idx, &qk_vec); |
| 95 | + VecF rot_half_vec = {-to_float(qk_vec[1]), |
| 96 | + to_float(qk_vec[0]), |
| 97 | + -to_float(qk_vec[3]), |
| 98 | + to_float(qk_vec[2])}; |
| 99 | + VecF cos_vec, sin_vec; |
| 100 | + Load(rot_cos_ptr + rot_base_idx, &cos_vec); |
| 101 | + Load(rot_sin_ptr + rot_base_idx, &sin_vec); |
| 102 | +#pragma unroll |
| 103 | + for (int i = 0; i < 4; ++i) { |
| 104 | + *(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] + |
| 105 | + rot_half_vec[i] * sin_vec[i]); |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +// qk and rope have a same type |
| 110 | +template <typename T> |
| 111 | +__global__ void DispatchApplyRopeVec4Kernel(const T* q, |
| 112 | + const T* k, |
| 113 | + const T* rot_cos, |
| 114 | + const T* rot_sin, |
| 115 | + const int q_num_elements, |
| 116 | + const int k_num_elements, |
| 117 | + const int q_head_num, |
| 118 | + const int k_head_num, |
| 119 | + const int head_dim, |
| 120 | + T* q_out, |
| 121 | + T* k_out) { |
| 122 | + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; |
| 123 | + int head_dim_idx = idx % head_dim; |
| 124 | + |
| 125 | + if (idx < q_num_elements) { |
| 126 | + int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; |
| 127 | + RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); |
| 128 | + } |
| 129 | + |
| 130 | + if (idx < k_num_elements) { |
| 131 | + int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; |
| 132 | + RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); |
| 133 | + } |
| 134 | +} |
| 135 | + |
| 136 | +// rope dtype is float32 |
| 137 | +template <typename T> |
| 138 | +__global__ void DispatchApplyRopeVec4Kernel(const T* q, |
| 139 | + const T* k, |
| 140 | + const float* rot_cos, |
| 141 | + const float* rot_sin, |
| 142 | + const int q_num_elements, |
| 143 | + const int k_num_elements, |
| 144 | + const int q_head_num, |
| 145 | + const int k_head_num, |
| 146 | + const int head_dim, |
| 147 | + T* q_out, |
| 148 | + T* k_out) { |
| 149 | + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; |
| 150 | + int head_dim_idx = idx % head_dim; |
| 151 | + |
| 152 | + if (idx < q_num_elements) { |
| 153 | + int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; |
| 154 | + RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); |
| 155 | + } |
| 156 | + |
| 157 | + if (idx < k_num_elements) { |
| 158 | + int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; |
| 159 | + RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +template <paddle::DataType D> |
| 164 | +void ApplyRopeKernel(const paddle::Tensor& q, |
| 165 | + const paddle::Tensor& k, |
| 166 | + const paddle::Tensor& rot_cos, |
| 167 | + const paddle::Tensor& rot_sin, |
| 168 | + paddle::Tensor& q_out, |
| 169 | + paddle::Tensor& k_out) { |
| 170 | + typedef PDTraits<D> traits_; |
| 171 | + typedef typename traits_::DataType DataType_; |
| 172 | + typedef typename traits_::data_t data_t; |
| 173 | + |
| 174 | + const auto q_num_elements = q.numel(); |
| 175 | + const auto k_num_elements = k.numel(); |
| 176 | + const auto q_shape = q.shape(); |
| 177 | + const auto k_shape = k.shape(); |
| 178 | + const auto dims = q_shape.size(); |
| 179 | + const auto q_head_num = q_shape[dims - 2]; |
| 180 | + const auto k_head_num = k_shape[dims - 2]; |
| 181 | + const auto head_dim = q_shape.back(); |
| 182 | + int block_num = |
| 183 | + (std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) / |
| 184 | + (THREADS_PER_BLOCK * 4); |
| 185 | + auto stream = q.stream(); |
| 186 | + |
| 187 | + if (q.dtype() == rot_cos.dtype()) { |
| 188 | + DispatchApplyRopeVec4Kernel<DataType_> |
| 189 | + <<<block_num, THREADS_PER_BLOCK, 0, stream>>>( |
| 190 | + reinterpret_cast<const DataType_*>(q.data<data_t>()), |
| 191 | + reinterpret_cast<const DataType_*>(k.data<data_t>()), |
| 192 | + reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()), |
| 193 | + reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()), |
| 194 | + q_num_elements, |
| 195 | + k_num_elements, |
| 196 | + q_head_num, |
| 197 | + k_head_num, |
| 198 | + head_dim, |
| 199 | + reinterpret_cast<DataType_*>(q_out.data<data_t>()), |
| 200 | + reinterpret_cast<DataType_*>(k_out.data<data_t>())); |
| 201 | + } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) { |
| 202 | + DispatchApplyRopeVec4Kernel<DataType_> |
| 203 | + <<<block_num, THREADS_PER_BLOCK, 0, stream>>>( |
| 204 | + reinterpret_cast<const DataType_*>(q.data<data_t>()), |
| 205 | + reinterpret_cast<const DataType_*>(k.data<data_t>()), |
| 206 | + reinterpret_cast<const float*>(rot_cos.data<float>()), |
| 207 | + reinterpret_cast<const float*>(rot_sin.data<float>()), |
| 208 | + q_num_elements, |
| 209 | + k_num_elements, |
| 210 | + q_head_num, |
| 211 | + k_head_num, |
| 212 | + head_dim, |
| 213 | + reinterpret_cast<DataType_*>(q_out.data<data_t>()), |
| 214 | + reinterpret_cast<DataType_*>(k_out.data<data_t>())); |
| 215 | + } else { |
| 216 | + PD_THROW("Unsupported qk dtype and rope dtype."); |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q, |
| 221 | + const paddle::Tensor& k, |
| 222 | + const paddle::Tensor& rot_cos, |
| 223 | + const paddle::Tensor& rot_sin) { |
| 224 | + auto q_shape = q.shape(); |
| 225 | + auto cos_shape = rot_cos.shape(); |
| 226 | + |
| 227 | + auto q_out = paddle::empty_like(q); |
| 228 | + auto k_out = paddle::empty_like(k); |
| 229 | + |
| 230 | + if (q.numel() == 0 || k.numel() == 0) { |
| 231 | + return {q_out, k_out}; |
| 232 | + } |
| 233 | + |
| 234 | + PADDLE_ENFORCE_EQ( |
| 235 | + q_shape.back() % 2, |
| 236 | + 0, |
| 237 | + "The last dimension (head_dim) of qk must be an even number " |
| 238 | + "for RoPE, but got %d", |
| 239 | + q_shape.back()); |
| 240 | + PADDLE_ENFORCE_EQ(q_shape.size(), |
| 241 | + cos_shape.size(), |
| 242 | + "The shape size of cos mismatches the shape size of q, " |
| 243 | + "expect %d but got %d", |
| 244 | + q_shape.size(), |
| 245 | + cos_shape.size()); |
| 246 | + PADDLE_ENFORCE_EQ(q_shape.back(), |
| 247 | + cos_shape.back(), |
| 248 | + "The shape.back() of cos mismatches the shape.back() of q, " |
| 249 | + "expect %d but got %d", |
| 250 | + q_shape.back(), |
| 251 | + cos_shape.back()); |
| 252 | + |
| 253 | + auto input_type = q.dtype(); |
| 254 | + switch (input_type) { |
| 255 | + case paddle::DataType::BFLOAT16: |
| 256 | + ApplyRopeKernel<paddle::DataType::BFLOAT16>( |
| 257 | + q, k, rot_cos, rot_sin, q_out, k_out); |
| 258 | + break; |
| 259 | + case paddle::DataType::FLOAT16: |
| 260 | + ApplyRopeKernel<paddle::DataType::FLOAT16>( |
| 261 | + q, k, rot_cos, rot_sin, q_out, k_out); |
| 262 | + break; |
| 263 | + default: |
| 264 | + PD_THROW("Only support qk dtype of BF16 and F16"); |
| 265 | + } |
| 266 | + |
| 267 | + return {q_out, k_out}; |
| 268 | +} |
| 269 | + |
| 270 | +std::vector<std::vector<int64_t>> ApplyRopeInferShape( |
| 271 | + const std::vector<int64_t>& q_shape, |
| 272 | + const std::vector<int64_t>& k_shape, |
| 273 | + const std::vector<int64_t>& cos_shape, |
| 274 | + const std::vector<int64_t>& sin_shape) { |
| 275 | + return {q_shape, k_shape, cos_shape, sin_shape}; |
| 276 | +} |
| 277 | + |
| 278 | +std::vector<paddle::DataType> ApplyRopeInferDtype( |
| 279 | + const paddle::DataType& q_dtype, |
| 280 | + const paddle::DataType& k_dtype, |
| 281 | + const paddle::DataType& cos_dtype, |
| 282 | + const paddle::DataType& sin_dtype) { |
| 283 | + return {q_dtype, k_dtype, cos_dtype, sin_dtype}; |
| 284 | +} |
| 285 | + |
| 286 | +PD_BUILD_OP(apply_rope) |
| 287 | + .Inputs({"q", "k", "rot_cos", "rot_sin"}) |
| 288 | + .Outputs({"q_out", "k_out"}) |
| 289 | + .SetKernelFn(PD_KERNEL(ApplyRope)) |
| 290 | + .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape)) |
| 291 | + .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype)); |