Skip to content

Commit e794f5b

Browse files
committed
[Metax] optimize cutlass moe and flash attention backend
1 parent 3c8c0f0 commit e794f5b

File tree

5 files changed

+469
-161
lines changed

5 files changed

+469
-161
lines changed

custom_ops/metax_ops/apply_rope.cu

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <cuda_runtime.h>
16+
#include <paddle/extension.h>
17+
#include <algorithm>
18+
#include "helper.h"
19+
20+
#define THREADS_PER_BLOCK 128
21+
22+
// Scalar <-> float conversion helper for half-precision types; only the
// __half and __nv_bfloat16 specializations below are defined.
template <typename T>
struct Converter;
24+
25+
// Conversions between __half and float/int (round-to-nearest-even for the
// narrowing directions).
template <>
struct Converter<__half> {
  // __half -> float (exact widening).
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half
  // FIX: the parameter was declared `float`, so callers' ints were widened
  // to float and then implicitly truncated back to int at the
  // __int2half_rn call. Take `int` directly, matching the intrinsic's
  // signature and the __nv_bfloat16 specialization below.
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};
36+
37+
// Conversions between __nv_bfloat16 and float/int (round-to-nearest-even
// for the narrowing directions).
template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float (exact widening).
  __device__ static float to_float(__nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float x) {
    return __float2bfloat16_rn(x);
  }

  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int x) {
    return __int2bfloat16_rn(x);
  }
};
52+
53+
// Apply rotary position embedding to one aligned group of 4 consecutive
// q/k elements, with the cos/sin tables stored in the same dtype T as the
// activations. All arithmetic stays in T precision (see the float-table
// overload for a variant that widens to float).
//
// qk_ptr       : base pointer of the q or k tensor.
// rot_cos_ptr  : cos table, same dtype as qk_ptr.
// rot_sin_ptr  : sin table, same dtype as qk_ptr.
// head_num     : unused in this overload; kept so both overloads share the
//                same call shape.
// base_idx     : element offset of the first of the 4 values in qk_ptr/out.
//                Must be 4-aligned for the vectorized Load.
// rot_base_idx : element offset of the matching 4 cos/sin values.
// out          : destination (same layout as qk_ptr); only the 4 elements
//                at base_idx are written.
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const T* rot_cos_ptr,
                             const T* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;

  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  // Rotate-half companion of (x0, x1, x2, x3): (-x1, x0, -x3, x2) — each
  // adjacent (even, odd) pair is rotated together.
  VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
  VecT cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // out = x * cos + rotate_half(x) * sin, evaluated in T precision.
    *(out + base_idx + i) =
        qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
  }
}
75+
76+
// Apply rotary position embedding to one aligned group of 4 consecutive
// q/k elements, with cos/sin stored in float32. Activations are widened to
// float for the multiply-add and narrowed back to T on store.
//
// qk_ptr       : base pointer of the q or k tensor (dtype T).
// rot_cos_ptr  : float32 cos table.
// rot_sin_ptr  : float32 sin table.
// head_num     : unused in this overload; kept so both overloads share the
//                same call shape.
// base_idx     : element offset of the first of the 4 values in qk_ptr/out.
//                Must be 4-aligned for the vectorized Load.
// rot_base_idx : element offset of the matching 4 cos/sin values.
// out          : destination (same layout as qk_ptr); only the 4 elements
//                at base_idx are written.
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const float* rot_cos_ptr,
                             const float* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;
  using VecF = AlignedVector<float, 4>;

  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  // Rotate-half companion of (x0, x1, x2, x3), widened to float:
  // (-x1, x0, -x3, x2) — each adjacent (even, odd) pair rotates together.
  VecF rot_half_vec = {-Converter<T>::to_float(qk_vec[1]),
                       Converter<T>::to_float(qk_vec[0]),
                       -Converter<T>::to_float(qk_vec[3]),
                       Converter<T>::to_float(qk_vec[2])};
  VecF cos_vec;
  VecF sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    // out = x * cos + rotate_half(x) * sin, accumulated in float.
    const float rotated =
        Converter<T>::to_float(qk_vec[lane]) * cos_vec[lane] +
        rot_half_vec[lane] * sin_vec[lane];
    out[base_idx + lane] = Converter<T>::from_float(rotated);
  }
}
108+
109+
// qk and rope have a same type.
// Vectorized (x4) RoPE kernel: each thread rotates 4 consecutive elements
// of q and, when also in range, 4 consecutive elements of k.
// Expects q/k laid out as [..., head_num, head_dim] with one cos/sin entry
// per (token, head_dim) position shared across heads.
// Launch: 1-D grid; q_num_elements and k_num_elements are assumed to be
// multiples of 4 (no partial-vector tail handling).
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const T* rot_cos,
                                            const T* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  // FIX: widen before multiplying — (blockIdx.x * blockDim.x + threadIdx.x)
  // * 4 evaluated in 32-bit int can overflow on large launches even though
  // every in-range element index itself fits in `int`.
  const int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * 4;
  const int head_dim_idx = static_cast<int>(idx % head_dim);

  if (idx < q_num_elements) {
    // token_index * head_dim + offset-in-head: cos/sin are per token,
    // shared by all q heads.
    const int rot_idx =
        static_cast<int>(idx / (q_head_num * head_dim)) * head_dim +
        head_dim_idx;
    RotateQKVec4(
        q, rot_cos, rot_sin, q_head_num, static_cast<int>(idx), rot_idx, q_out);
  }

  if (idx < k_num_elements) {
    const int rot_idx =
        static_cast<int>(idx / (k_head_num * head_dim)) * head_dim +
        head_dim_idx;
    RotateQKVec4(
        k, rot_cos, rot_sin, k_head_num, static_cast<int>(idx), rot_idx, k_out);
  }
}
135+
136+
// rope dtype is float32.
// Same vectorized (x4) RoPE kernel as above, but cos/sin are float32 and
// the rotation is computed in float (see the float-table RotateQKVec4
// overload). Same layout and launch assumptions: q/k are
// [..., head_num, head_dim], element counts are multiples of 4.
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const float* rot_cos,
                                            const float* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  // FIX: widen before multiplying — (blockIdx.x * blockDim.x + threadIdx.x)
  // * 4 evaluated in 32-bit int can overflow on large launches even though
  // every in-range element index itself fits in `int`.
  const int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * 4;
  const int head_dim_idx = static_cast<int>(idx % head_dim);

  if (idx < q_num_elements) {
    // token_index * head_dim + offset-in-head: cos/sin are per token,
    // shared by all q heads.
    const int rot_idx =
        static_cast<int>(idx / (q_head_num * head_dim)) * head_dim +
        head_dim_idx;
    RotateQKVec4(
        q, rot_cos, rot_sin, q_head_num, static_cast<int>(idx), rot_idx, q_out);
  }

  if (idx < k_num_elements) {
    const int rot_idx =
        static_cast<int>(idx / (k_head_num * head_dim)) * head_dim +
        head_dim_idx;
    RotateQKVec4(
        k, rot_cos, rot_sin, k_head_num, static_cast<int>(idx), rot_idx, k_out);
  }
}
162+
163+
// Typed launcher: resolves the element type from the compile-time dtype D,
// sizes the 1-D grid so each thread covers 4 elements of the larger of
// q/k, and dispatches the vec4 RoPE kernel with either a same-dtype or a
// float32 cos/sin table. Throws for any other dtype combination.
template <paddle::DataType D>
void ApplyRopeKernel(const paddle::Tensor& q,
                     const paddle::Tensor& k,
                     const paddle::Tensor& rot_cos,
                     const paddle::Tensor& rot_sin,
                     paddle::Tensor& q_out,
                     paddle::Tensor& k_out) {
  using traits_ = PDTraits<D>;
  using DataType_ = typename traits_::DataType;
  using data_t = typename traits_::data_t;

  const auto q_num_elements = q.numel();
  const auto k_num_elements = k.numel();
  const auto q_shape = q.shape();
  const auto k_shape = k.shape();
  const auto dims = q_shape.size();
  // q/k are [..., head_num, head_dim]; q and k may differ in head count
  // (GQA) but share head_dim.
  const auto q_head_num = q_shape[dims - 2];
  const auto k_head_num = k_shape[dims - 2];
  const auto head_dim = q_shape.back();

  // Each thread processes 4 elements, so one block covers
  // THREADS_PER_BLOCK * 4 elements; ceil-divide over the larger tensor.
  constexpr int kElemsPerBlock = THREADS_PER_BLOCK * 4;
  const auto max_elements = std::max(q_num_elements, k_num_elements);
  int block_num = (max_elements + kElemsPerBlock - 1) / kElemsPerBlock;
  auto stream = q.stream();

  if (q.dtype() == rot_cos.dtype()) {
    // cos/sin stored in the same precision as the activations.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
    // Half/bf16 activations rotated against a float32 cos/sin table.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            rot_cos.data<float>(),
            rot_sin.data<float>(),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }
}
219+
220+
// Entry point for the apply_rope custom op: validates shapes, allocates the
// outputs, and dispatches by dtype (BF16/FP16 only) to ApplyRopeKernel.
// Returns {q_out, k_out} with the same shapes/dtypes as q and k.
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
                                      const paddle::Tensor& k,
                                      const paddle::Tensor& rot_cos,
                                      const paddle::Tensor& rot_sin) {
  auto q_shape = q.shape();
  auto cos_shape = rot_cos.shape();

  auto q_out = paddle::empty_like(q);
  auto k_out = paddle::empty_like(k);

  // Nothing to rotate — skip validation and the kernel launch.
  // NOTE(review): if exactly one of q/k is empty, the non-empty side is
  // returned uninitialized; confirm callers never pass mismatched empties.
  if (q.numel() == 0 || k.numel() == 0) {
    return {q_out, k_out};
  }

  // head_dim must be even so adjacent (even, odd) pairs can be rotated.
  // NOTE(review): the dispatched kernel loads/stores vectors of 4, which
  // appears to additionally require head_dim % 4 == 0 (and hence
  // numel % 4 == 0); only % 2 is enforced here — confirm supported
  // head_dim values.
  PADDLE_ENFORCE_EQ(
      q_shape.back() % 2,
      0,
      "The last dimension (head_dim) of qk must be an even number "
      "for RoPE, but got %d",
      q_shape.back());
  // cos must have the same rank as q (broadcast over the head axis is done
  // via index arithmetic inside the kernel, not via shape).
  PADDLE_ENFORCE_EQ(q_shape.size(),
                    cos_shape.size(),
                    "The shape size of cos mismatches the shape size of q, "
                    "expect %d but got %d",
                    q_shape.size(),
                    cos_shape.size());
  // cos must share head_dim with q so per-position cos/sin line up.
  PADDLE_ENFORCE_EQ(q_shape.back(),
                    cos_shape.back(),
                    "The shape.back() of cos mismatches the shape.back() of q, "
                    "expect %d but got %d",
                    q_shape.back(),
                    cos_shape.back());

  // Dtype dispatch: only half-precision activations are supported.
  auto input_type = q.dtype();
  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeKernel<paddle::DataType::BFLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeKernel<paddle::DataType::FLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out};
}
269+
270+
// Shape inference for apply_rope. The op declares exactly two outputs
// (q_out, k_out), each shaped like its corresponding input.
// FIX: previously returned four shapes (including cos/sin) for a
// two-output op; the entry count must match the registered Outputs list.
std::vector<std::vector<int64_t>> ApplyRopeInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& k_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  return {q_shape, k_shape};
}
277+
278+
// Dtype inference for apply_rope. Outputs inherit the dtypes of q and k.
// FIX: previously returned four dtypes (including cos/sin) for a
// two-output op; the entry count must match the registered Outputs list.
std::vector<paddle::DataType> ApplyRopeInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& k_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  return {q_dtype, k_dtype};
}
285+
286+
// Register the `apply_rope` custom operator with Paddle: four tensor
// inputs (q, k and the cos/sin tables), two tensor outputs, with the
// compute, shape-inference and dtype-inference functions defined above.
PD_BUILD_OP(apply_rope)
    .Inputs({"q", "k", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out"})
    .SetKernelFn(PD_KERNEL(ApplyRope))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));

0 commit comments

Comments
 (0)