diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu
index 369242d197b..e328c56a8fc 100644
--- a/custom_ops/gpu_ops/rebuild_padding.cu
+++ b/custom_ops/gpu_ops/rebuild_padding.cu
@@ -24,24 +24,23 @@ __global__ void RebuildPaddingKernel(T *output_data,
                                      const int max_input_length,
                                      const int dim_embed,
                                      const int elem_nums) {
-    using LoadT = AlignedVector<T, VecSize>;
-    LoadT src_vec;
-    const int global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-    for (int i = global_idx * VecSize; i < elem_nums;
-         i += gridDim.x * blockDim.x * VecSize) {
-        const int bi = i / dim_embed;
-        const int bias_idx = i % dim_embed;
-        int seq_id = 0;
-        if (seq_len_this_time[bi] == 0) continue;
-        if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
-        if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
+  using LoadT = AlignedVector<T, VecSize>;
+  LoadT src_vec;
+  const int global_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (int i = global_idx * VecSize; i < elem_nums;
+       i += gridDim.x * blockDim.x * VecSize) {
+    const int bi = i / dim_embed;
+    const int bias_idx = i % dim_embed;
+    int seq_id = 0;
+    if (seq_len_this_time[bi] == 0) continue;
+    if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
+    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
 
-        const int ori_token_idx =
-            cu_seqlens_q[bi] + seq_id;
-        const int src_offset = ori_token_idx * dim_embed + bias_idx;
-        Load<T, VecSize>(&input_data[src_offset], &src_vec);
-        Store<T, VecSize>(src_vec, &output_data[i]);
-    }
+    const int ori_token_idx = cu_seqlens_q[bi] + seq_id;
+    const int src_offset = ori_token_idx * dim_embed + bias_idx;
+    Load<T, VecSize>(&input_data[src_offset], &src_vec);
+    Store<T, VecSize>(src_vec, &output_data[i]);
+  }
 }
 
 template <typename T, int VecSize>
@@ -58,41 +57,40 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
                                            const int64_t output_elem_nums,
                                            const int bsz,
                                            const bool enable_logprob) {
-    AlignedVector<T, VecSize> src_vec;
-    const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-    for (int64_t i = global_idx * VecSize; i < output_elem_nums;
-         i += gridDim.x * blockDim.x * VecSize) {
-        const int out_token_id = i / dim_embed;
-        const int ori_token_id =
-            out_token_id + output_padding_offset[out_token_id];
+  AlignedVector<T, VecSize> src_vec;
+  const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (int64_t i = global_idx * VecSize; i < output_elem_nums;
+       i += gridDim.x * blockDim.x * VecSize) {
+    const int out_token_id = i / dim_embed;
+    const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
 
-        const int bi = ori_token_id / max_input_length;
+    const int bi = ori_token_id / max_input_length;
 
-        int seq_id = 0;
-        if (seq_len_this_time[bi] == 0) continue;
-        if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
+    int seq_id = 0;
+    if (seq_len_this_time[bi] == 0) continue;
+    if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
 
-        if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
-        const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
-        const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
-        const int bias_idx = i % dim_embed;
+    if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
+    const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
+    const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
+    const int bias_idx = i % dim_embed;
 
-        Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
-                         &src_vec);
-        Store<T, VecSize>(src_vec, &output_data[i]);
+    Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
+                     &src_vec);
+    Store<T, VecSize>(src_vec, &output_data[i]);
 
-        if (enable_logprob && seq_len_encoder[bi] > 0) {
-            const int first_input_token_id = input_token_id - 1;
-            Load<T, VecSize>(&input_data[first_input_token_id * dim_embed + bias_idx],
-                             &src_vec);
-            Store<T, VecSize>(src_vec, &first_token_out[bi * dim_embed + bias_idx]);
-        }
+    if (enable_logprob && seq_len_encoder[bi] > 0) {
+      const int first_input_token_id = input_token_id - 1;
+      Load<T, VecSize>(&input_data[first_input_token_id * dim_embed + bias_idx],
+                       &src_vec);
+      Store<T, VecSize>(src_vec, &first_token_out[bi * dim_embed + bias_idx]);
     }
+  }
 }
 
 template <paddle::DataType D>
 std::vector<paddle::Tensor> rebuild_padding(
-    const paddle::Tensor &tmp_out,  // [token_num, dim_embed]
+    const paddle::Tensor &tmp_out,       // [token_num, dim_embed]
     const paddle::Tensor &cu_seqlens_q,  // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
@@ -101,84 +99,85 @@ std::vector<paddle::Tensor> rebuild_padding(
     const paddle::optional<paddle::Tensor> &first_token_out,
     int max_input_length,
     bool enable_logprob) {
-    typedef PDTraits<D> traits_;
-    typedef typename traits_::DataType DataType_;
-    typedef typename traits_::data_t data_t;
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-    auto dev_ctx = static_cast<const phi::CustomContext *>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
-    auto cu_stream = dev_ctx->stream();
+  auto dev_ctx = static_cast<const phi::CustomContext *>(
+      paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
+  auto cu_stream = dev_ctx->stream();
 #else
-    auto cu_stream = tmp_out.stream();
+  auto cu_stream = tmp_out.stream();
 #endif
-    std::vector<int64_t> tmp_out_shape = tmp_out.shape();
-    const int token_num = tmp_out_shape[0];
-    const int dim_embed = tmp_out_shape[1];
-    const int bsz = cu_seqlens_q.shape()[0] - 1;
+  std::vector<int64_t> tmp_out_shape = tmp_out.shape();
+  const int token_num = tmp_out_shape[0];
+  const int dim_embed = tmp_out_shape[1];
+  const int bsz = cu_seqlens_q.shape()[0] - 1;
 
-    paddle::Tensor out;
-    if (output_padding_offset) {
-        int need_delete_token_num = 0;
-        auto seq_lens_encoder_cpu =
-            seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
-        for (int i = 0; i < bsz; ++i) {
-            if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
-                need_delete_token_num +=
-                    seq_lens_encoder_cpu.data<int>()[i] - 1;
-            }
-        }
-        out = paddle::full({token_num - need_delete_token_num, dim_embed},
-                           0,
-                           D,
-                           tmp_out.place());
-    } else {
-        out =
-            paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
-    }
+  paddle::Tensor out;
+  if (output_padding_offset) {
+    int need_delete_token_num = 0;
+    auto seq_lens_encoder_cpu =
+        seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
+    for (int i = 0; i < bsz; ++i) {
+      if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
+        need_delete_token_num += seq_lens_encoder_cpu.data<int>()[i] - 1;
+      }
+    }
+    out = paddle::full(
+        {token_num - need_delete_token_num, dim_embed}, 0, D, tmp_out.place());
 
-    constexpr int PackSize = VEC_16B / sizeof(DataType_);
-    int elem_nums = out.numel();
-    int pack_num = elem_nums / PackSize;
-    const int blocksize = 128;
-    const int grid_size = (pack_num + blocksize - 1) / blocksize;
-    if (output_padding_offset) {
-        RebuildAppendPaddingKernel<DataType_, PackSize>
-            <<<grid_size, blocksize, 0, cu_stream>>>(
-                reinterpret_cast<DataType_ *>(out.data<data_t>()),
-                first_token_out.is_initialized()
-                    ? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
-                          first_token_out.get_ptr()->data<data_t>()))
-                    : nullptr,
-                reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
-                cu_seqlens_q.data<int>(),
-                seq_len_this_time.data<int>(),
-                seq_lens_decoder.data<int>(),
-                seq_lens_encoder.data<int>(),
-                output_padding_offset.get_ptr()->data<int>(),
-                max_input_length,
-                dim_embed,
-                elem_nums,
-                bsz,
-                enable_logprob);
-    } else {
-        RebuildPaddingKernel<DataType_, PackSize>
-            <<<grid_size, blocksize, 0, cu_stream>>>(
-                reinterpret_cast<DataType_ *>(out.data<data_t>()),
-                reinterpret_cast<DataType_ *>(
-                    const_cast<data_t *>(tmp_out.data<data_t>())),
-                cu_seqlens_q.data<int>(),
-                seq_len_this_time.data<int>(),
-                seq_lens_decoder.data<int>(),
-                seq_lens_encoder.data<int>(),
-                max_input_length,
-                dim_embed,
-                elem_nums);
-    }
-    return {out};
+    PADDLE_ENFORCE(out.shape()[0] == output_padding_offset.get().shape()[0],
+                   "Unmatched shape");
+
+  } else {
+    out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
+  }
+
+  constexpr int PackSize = VEC_16B / sizeof(DataType_);
+  int elem_nums = out.numel();
+  int pack_num = elem_nums / PackSize;
+  const int blocksize = 128;
+  const int grid_size = (pack_num + blocksize - 1) / blocksize;
+  if (output_padding_offset) {
+    RebuildAppendPaddingKernel<DataType_, PackSize>
+        <<<grid_size, blocksize, 0, cu_stream>>>(
+            reinterpret_cast<DataType_ *>(out.data<data_t>()),
+            first_token_out.is_initialized()
+                ? reinterpret_cast<DataType_ *>(const_cast<data_t *>(
+                      first_token_out.get_ptr()->data<data_t>()))
+                : nullptr,
+            reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
+            cu_seqlens_q.data<int>(),
+            seq_len_this_time.data<int>(),
+            seq_lens_decoder.data<int>(),
+            seq_lens_encoder.data<int>(),
+            output_padding_offset.get_ptr()->data<int>(),
+            max_input_length,
+            dim_embed,
+            elem_nums,
+            bsz,
+            enable_logprob);
+  } else {
+    RebuildPaddingKernel<DataType_, PackSize>
+        <<<grid_size, blocksize, 0, cu_stream>>>(
+            reinterpret_cast<DataType_ *>(out.data<data_t>()),
+            reinterpret_cast<DataType_ *>(
+                const_cast<data_t *>(tmp_out.data<data_t>())),
+            cu_seqlens_q.data<int>(),
+            seq_len_this_time.data<int>(),
+            seq_lens_decoder.data<int>(),
+            seq_lens_encoder.data<int>(),
+            max_input_length,
+            dim_embed,
+            elem_nums);
+  }
+  return {out};
 }
 
 paddle::Tensor RebuildPaddingFunc(
-    const paddle::Tensor &tmp_out,  // [token_num, dim_embed]
+    const paddle::Tensor &tmp_out,       // [token_num, dim_embed]
     const paddle::Tensor &cu_seqlens_q,  // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
@@ -187,55 +186,52 @@ paddle::Tensor RebuildPaddingFunc(
     const paddle::optional<paddle::Tensor> &first_token_out,
     int max_input_length,
     bool enable_logprob) {
-    switch (tmp_out.type()) {
-        case paddle::DataType::BFLOAT16: {
-            return rebuild_padding<paddle::DataType::BFLOAT16>(
-                tmp_out,
-                cu_seqlens_q,
-                seq_len_this_time,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                first_token_out,
-                max_input_length,
-                enable_logprob)[0];
-        }
-        case paddle::DataType::FLOAT16: {
-            return rebuild_padding<paddle::DataType::FLOAT16>(
-                tmp_out,
-                cu_seqlens_q,
-                seq_len_this_time,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                first_token_out,
-                max_input_length,
-                enable_logprob)[0];
-        }
-        case paddle::DataType::FLOAT32: {
-            return rebuild_padding<paddle::DataType::FLOAT32>(
-                tmp_out,
-                cu_seqlens_q,
-                seq_len_this_time,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                first_token_out,
-                max_input_length,
-                enable_logprob)[0];
-        }
-        default: {
-            PD_THROW(
-                "NOT supported data type. "
-                "Only float16, bfloat16 and float32 are supported. ");
-            break;
-        }
-    }
+  switch (tmp_out.type()) {
+    case paddle::DataType::BFLOAT16: {
+      return rebuild_padding<paddle::DataType::BFLOAT16>(tmp_out,
+                                                         cu_seqlens_q,
+                                                         seq_len_this_time,
+                                                         seq_lens_decoder,
+                                                         seq_lens_encoder,
+                                                         output_padding_offset,
+                                                         first_token_out,
+                                                         max_input_length,
+                                                         enable_logprob)[0];
+    }
+    case paddle::DataType::FLOAT16: {
+      return rebuild_padding<paddle::DataType::FLOAT16>(tmp_out,
+                                                        cu_seqlens_q,
+                                                        seq_len_this_time,
+                                                        seq_lens_decoder,
+                                                        seq_lens_encoder,
+                                                        output_padding_offset,
+                                                        first_token_out,
+                                                        max_input_length,
+                                                        enable_logprob)[0];
+    }
+    case paddle::DataType::FLOAT32: {
+      return rebuild_padding<paddle::DataType::FLOAT32>(tmp_out,
+                                                        cu_seqlens_q,
+                                                        seq_len_this_time,
+                                                        seq_lens_decoder,
+                                                        seq_lens_encoder,
+                                                        output_padding_offset,
+                                                        first_token_out,
+                                                        max_input_length,
+                                                        enable_logprob)[0];
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16, bfloat16 and float32 are supported. ");
+      break;
+    }
+  }
 }
 
 std::vector<paddle::Tensor> RebuildPadding(
-    const paddle::Tensor &tmp_out,  // [token_num, dim_embed]
-    const paddle::Tensor &cu_seqlens_q,   // [bsz+1, 1]
+    const paddle::Tensor &tmp_out,       // [token_num, dim_embed]
+    const paddle::Tensor &cu_seqlens_q,  // [bsz+1, 1]
     const paddle::Tensor &seq_len_this_time,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_encoder,
@@ -243,15 +239,15 @@ std::vector<paddle::Tensor> RebuildPadding(
     const paddle::optional<paddle::Tensor> &first_token_out,
     int max_input_length,
     bool enable_logprob) {
-    return {RebuildPaddingFunc(tmp_out,
-                               cu_seqlens_q,
-                               seq_len_this_time,
-                               seq_lens_decoder,
-                               seq_lens_encoder,
-                               output_padding_offset,
-                               first_token_out,
-                               max_input_length,
-                               enable_logprob)};
+  return {RebuildPaddingFunc(tmp_out,
+                             cu_seqlens_q,
+                             seq_len_this_time,
+                             seq_lens_decoder,
+                             seq_lens_encoder,
+                             output_padding_offset,
+                             first_token_out,
+                             max_input_length,
+                             enable_logprob)};
 }
 
 std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
@@ -261,14 +257,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
     const std::vector<int64_t> &seq_lens_decoder_shape,
     const std::vector<int64_t> &seq_lens_encoder_shape,
     const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
-    int64_t dim_embed = tmp_out_shape[1];
-    // whether speculative decoding
-    if (output_padding_offset_shape) {
-        return {{-1, dim_embed}};
-    } else {
-        int64_t bsz = cu_seqlens_q_shape[0] - 1;
-        return {{bsz, dim_embed}};
-    }
+  int64_t dim_embed = tmp_out_shape[1];
+  // whether speculative decoding
+  if (output_padding_offset_shape) {
+    return {{-1, dim_embed}};
+  } else {
+    int64_t bsz = cu_seqlens_q_shape[0] - 1;
+    return {{bsz, dim_embed}};
+  }
 }
 
 std::vector<paddle::DataType> RebuildPaddingInferDtype(
@@ -278,7 +274,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
     const paddle::DataType &seq_lens_decoder_dtype,
     const paddle::DataType &seq_lens_encoder_dtype,
     const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
-    return {tmp_out_dtype};
+  return {tmp_out_dtype};
 }
 
 PD_BUILD_STATIC_OP(rebuild_padding)
diff --git a/tests/operators/test_rebuild_padding.py b/tests/operators/test_rebuild_padding.py
index d87c4b203a8..6b8db57016a 100644
--- a/tests/operators/test_rebuild_padding.py
+++ b/tests/operators/test_rebuild_padding.py
@@ -168,7 +168,7 @@ def test_rebuild_padding_no_offset(self):
     # test with offset
     def test_rebuild_padding_with_offset(self):
         paddle.seed(42)
-        token_num = 100
+        token_num = 84
         dim_embed = 256
         # bsz = 4
         max_input_length = 512
@@ -184,7 +184,7 @@ def test_rebuild_padding_with_offset(self):
         seq_lens_encoder = np.array([0, 20, 0, 20, 0, 20, 0, 20], dtype=np.int32)
         seq_lens_decoder = np.array([21, 0, 21, 0, 21, 0, 21, 0], dtype=np.int32)
 
-        num_output_tokens = 80
+        num_output_tokens = 8
         output_padding_offset = np.random.randint(0, 10, [num_output_tokens], dtype=np.int32)
         out_with_offset_ref = rebuild_padding_ref(
             tmp_out=tmp_out,
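
Note on the revised test constants (a reading of the diff, with one assumption): the hunks show
seq_lens_encoder = [0, 20, 0, 20, 0, 20, 0, 20] and seq_lens_decoder = [21, 0, 21, 0, 21, 0, 21, 0].
Assuming each of the four decoder sequences contributes one token this step, the packed tmp_out
buffer holds 4 * 20 + 4 * 1 = 84 rows (hence token_num = 84), and the speculative path keeps
token_num - need_delete_token_num = 84 - 4 * (20 - 1) = 8 rows (hence num_output_tokens = 8),
which is exactly the leading dimension the new PADDLE_ENFORCE compares against
output_padding_offset.

The sketch below mirrors the index math of the non-speculative RebuildPaddingKernel in NumPy.
It is a minimal illustration, not the repo's rebuild_padding_ref; the function name is
hypothetical.

    import numpy as np

    def rebuild_padding_sketch(tmp_out, cu_seqlens_q, seq_len_this_time,
                               seq_lens_decoder, seq_lens_encoder):
        # For each sequence, gather the hidden state of the last token
        # produced this step from the packed [token_num, dim_embed] buffer.
        bsz = cu_seqlens_q.shape[0] - 1
        out = np.zeros([bsz, tmp_out.shape[1]], dtype=tmp_out.dtype)
        for bi in range(bsz):
            if seq_len_this_time[bi] == 0:
                continue
            if seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0:
                continue
            # Prefill: take the last prompt token; decode: the single new token.
            seq_id = seq_lens_encoder[bi] - 1 if seq_lens_encoder[bi] > 0 else 0
            out[bi] = tmp_out[cu_seqlens_q[bi] + seq_id]
        return out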