From ced9c87e2478ff3d941f7fb59cf2ae35405a422c Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Sun, 26 Oct 2025 00:27:27 +0800
Subject: [PATCH 1/4] support mask_offset in speculate decoding

---
 custom_ops/gpu_ops/cpp_extensions.cc          |  16 ++-
 .../speculate_decoding/speculate_update.cu    |  21 ++-
 .../update_attn_mask_offsets.cu               | 126 ++++++++++++++++++
 .../layers/attention/append_attn_backend.py   |   5 +-
 .../model_executor/pre_and_post_process.py    |   1 +
 fastdeploy/spec_decode/base.py                |   2 +
 fastdeploy/spec_decode/mtp.py                 |  42 ++++++
 fastdeploy/worker/gpu_model_runner.py         |   3 +
 fastdeploy/worker/output.py                   |   5 +
 tests/operators/test_tree_mask.py             |  26 +++-
 10 files changed, 234 insertions(+), 13 deletions(-)
 create mode 100644 custom_ops/gpu_ops/speculate_decoding/update_attn_mask_offsets.cu

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 956805802f9..9e2978d08da 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -700,7 +700,8 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                      const paddle::Tensor &stop_flags,
                      const paddle::Tensor &seq_lens_this_time,
                      const paddle::Tensor &is_block_step,
-                     const paddle::Tensor &stop_nums);
+                     const paddle::Tensor &stop_nums,
+                     const paddle::Tensor &mask_rollback);
 
 void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                     const paddle::Tensor &accept_tokens,
@@ -934,6 +935,17 @@ void SpeculateGetTargetLogits(const paddle::Tensor &target_logits,
                               const paddle::Tensor &seq_lens_encoder,
                               const paddle::Tensor &accept_num);
 
+std::vector<paddle::Tensor> UpdateAttnMaskOffsets(const paddle::Tensor& ids_remove_padding,
+                                                  const paddle::Tensor& seq_lens_this_time, // only on cpu
+                                                  const paddle::Tensor& seq_lens_encoder,
+                                                  const paddle::Tensor& seq_lens_decoder,
+                                                  const paddle::Tensor& cu_seqlens_q,
+                                                  const paddle::Tensor& attn_mask_offsets_full,
+                                                  const paddle::Tensor& attn_mask_offsets_decoder,
+                                                  const paddle::Tensor& is_block_step,
+                                                  const paddle::Tensor& decode_states,
+                                                  const paddle::Tensor& mask_rollback);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1328,4 +1340,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");
 
   m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function");
+
+  m.def("update_attn_mask_offsets", &UpdateAttnMaskOffsets, "update attention mask");
 }

diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu
index 828dc17285f..ff3b249c7e0 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_update.cu
@@ -26,6 +26,7 @@ __global__ void speculate_update(int *seq_lens_encoder,
                                  const int *seq_lens_this_time,
                                  const bool *is_block_step,
                                  const int64_t *stop_nums,
+                                 int *mask_rollback,
                                  const int real_bsz,
                                  const int max_bsz,
                                  const int max_draft_tokens) {
@@ -35,9 +36,12 @@ __global__ void speculate_update(int *seq_lens_encoder,
     if (!(is_block_step[bid] || bid >= real_bsz)) {
         if (stop_flags[bid]) {
             stop_flag_now_int = 1;
-        }
-        if (seq_lens_encoder[bid] == 0) {
+            mask_rollback[bid] = 0;
+        } else if (seq_lens_encoder[bid] == 0) {  // decoder
             seq_lens_decoder[bid] += accept_num_now;
+            mask_rollback[bid] = seq_lens_this_time[bid] - accept_num_now;
+        } else {  // encoder
+            mask_rollback[bid] = 0;
         }
 
         if (seq_lens_this_time[bid] > 1 &&
@@ -97,7 +101,8 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                      const paddle::Tensor &stop_flags,
                      const paddle::Tensor &seq_lens_this_time,
                      const paddle::Tensor &is_block_step,
-                     const paddle::Tensor &stop_nums) {
+                     const paddle::Tensor &stop_nums,
+                     const paddle::Tensor &mask_rollback) {
   const int real_bsz = seq_lens_this_time.shape()[0];
   const int max_bsz = stop_flags.shape()[0];
   auto max_draft_tokens = draft_tokens.shape()[1];
@@ -117,6 +122,7 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
       seq_lens_this_time.data<int>(),
       is_block_step.data<bool>(),
       stop_nums.data<int64_t>(),
+      const_cast<int *>(mask_rollback.data<int>()),
       real_bsz,
       max_bsz,
       max_draft_tokens);
@@ -138,15 +144,18 @@ PD_BUILD_STATIC_OP(speculate_update)
              "stop_flags",
              "seq_lens_this_time",
              "is_block_step",
-             "stop_nums"})
+             "stop_nums",
+             "mask_rollback"})
     .Outputs({"seq_lens_encoder_out",
               "seq_lens_decoder_out",
               "not_need_stop_out",
               "draft_tokens_out",
-              "actual_draft_token_nums_out"})
+              "actual_draft_token_nums_out",
+              "mask_rollback_out"})
     .SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
                     {"seq_lens_decoder", "seq_lens_decoder_out"},
                     {"not_need_stop", "not_need_stop_out"},
                     {"draft_tokens", "draft_tokens_out"},
-                    {"actual_draft_token_nums", "actual_draft_token_nums_out"}})
+                    {"actual_draft_token_nums", "actual_draft_token_nums_out"},
+                    {"mask_rollback", "mask_rollback_out"}})
     .SetKernelFn(PD_KERNEL(SpeculateUpdate));

diff --git a/custom_ops/gpu_ops/speculate_decoding/update_attn_mask_offsets.cu b/custom_ops/gpu_ops/speculate_decoding/update_attn_mask_offsets.cu
new file mode 100644
index 00000000000..24e839f805d
--- /dev/null
+++ b/custom_ops/gpu_ops/speculate_decoding/update_attn_mask_offsets.cu
@@ -0,0 +1,126 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+#include "paddle/extension.h"
+
+__global__ void update_attn_mask_offsets_kernel(int* attn_mask_offsets,
+                                                const int* seq_lens_this_time,
+                                                const int* seq_lens_encoder,
+                                                const int* seq_lens_decoder,
+                                                const int* cu_seqlens_q,
+                                                const int* attn_mask_offsets_full,
+                                                int* attn_mask_offsets_decoder,
+                                                const bool* is_block_step,
+                                                int* decode_states,
+                                                const int* mask_rollback,
+                                                const int real_bsz,
+                                                const int max_model_len,
+                                                const int decode_states_len) {
+    int tid = threadIdx.x;
+
+    for (int bid = tid; bid < real_bsz; bid += blockDim.x) {
+        int seq_len_this_time = seq_lens_this_time[bid];
+        int seq_len_encoder = seq_lens_encoder[bid];
+        int seq_len_decoder = seq_lens_decoder[bid];
+        int query_start_id = cu_seqlens_q[bid];
+
+        const int* attn_mask_offsets_full_now = attn_mask_offsets_full + bid * max_model_len;
+        int* decode_states_now = decode_states + bid * decode_states_len;
+        if (!is_block_step[bid]) {
+            if (seq_len_encoder == 0 && seq_len_decoder == 0) {
+                // Status: stop
+            } else if (seq_len_encoder > 0) {
+                // Status: prefill -- normal or chunk_prefill
+                for (int i = 0; i < seq_len_this_time; i++) {
+                    attn_mask_offsets[query_start_id + i] = attn_mask_offsets_full_now[i];
+                }
+            } else if (seq_len_decoder > 0) {
+                // Status: decoder -- normal or chunk_prefill
+                // TODO: support speculative decoding.
+                attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
+
+                for (int i = 0; i < seq_len_this_time; i++) {
+                    attn_mask_offsets[query_start_id + i] = attn_mask_offsets_decoder[bid] + i;
+                }
+                attn_mask_offsets_decoder[bid] += seq_len_this_time;
+
+                // Speculative decoding in text_generation
+                if (seq_len_this_time > 1) {
+                    for (int i = 0; i < decode_states_len; i++) {
+                        if (i < seq_len_this_time) {
+                            decode_states_now[i] = 0;
+                        } else {
+                            decode_states_now[i] = -1;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+std::vector<paddle::Tensor> UpdateAttnMaskOffsets(const paddle::Tensor& ids_remove_padding,
+                                                  const paddle::Tensor& seq_lens_this_time, // only on cpu
+                                                  const paddle::Tensor& seq_lens_encoder,
+                                                  const paddle::Tensor& seq_lens_decoder,
+                                                  const paddle::Tensor& cu_seqlens_q,
+                                                  const paddle::Tensor& attn_mask_offsets_full,
+                                                  const paddle::Tensor& attn_mask_offsets_decoder,
+                                                  const paddle::Tensor& is_block_step,
+                                                  const paddle::Tensor& decode_states,
+                                                  const paddle::Tensor& mask_rollback) {
+    int max_model_len = attn_mask_offsets_full.shape()[1];
+    int real_bsz = seq_lens_this_time.shape()[0];
+    int batch_seq_lens = ids_remove_padding.shape()[0];
+    int decode_states_len = decode_states.shape()[1];
+
+    auto attn_mask_offsets = paddle::empty(
+        {batch_seq_lens}, paddle::DataType::INT32, ids_remove_padding.place());
+
+    // launch config
+    int blockSize = 512;
+
+    update_attn_mask_offsets_kernel<<<1, blockSize, 0, ids_remove_padding.stream()>>>(
+        attn_mask_offsets.data<int>(),
+        seq_lens_this_time.data<int>(),
+        seq_lens_encoder.data<int>(),
+        seq_lens_decoder.data<int>(),
+        cu_seqlens_q.data<int>(),
+        attn_mask_offsets_full.data<int>(),
+        const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
+        is_block_step.data<bool>(),
+        const_cast<int*>(decode_states.data<int>()),
+        mask_rollback.data<int>(),
+        real_bsz,
+        max_model_len,
+        decode_states_len);
+
+    return {attn_mask_offsets};
+}
+
+PD_BUILD_STATIC_OP(update_attn_mask_offsets)
+    .Inputs({"ids_remove_padding",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "cu_seqlens_q",
+             "attn_mask_offsets_full",
+             "attn_mask_offsets_decoder",
+             "is_block_step",
+             "decode_states",
+             "mask_rollback"})
+    .Outputs({"attn_mask_offsets", "decode_states_out"})
+    .SetInplaceMap({{"decode_states", "decode_states_out"}})
+    .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
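The per-batch state machine in this kernel is the heart of the patch, so a compact reference model is useful when reviewing it. The following NumPy sketch mirrors update_attn_mask_offsets_kernel entry by entry; the helper name and the zero-initialized output array are illustrative assumptions of ours (the CUDA op writes into uninitialized paddle::empty storage), not part of the patch.

import numpy as np

def update_offsets_reference(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
                             cu_seqlens_q, offsets_full, offsets_decoder,
                             is_block_step, decode_states, mask_rollback):
    # Pure-Python model of update_attn_mask_offsets_kernel (sketch only).
    out = np.zeros(int(np.sum(seq_lens_this_time)), dtype=np.int32)
    for bid, n in enumerate(seq_lens_this_time):
        if is_block_step[bid]:
            continue  # requests parked in a block step are skipped entirely
        q0 = cu_seqlens_q[bid]
        if seq_lens_encoder[bid] > 0:
            # prefill: copy the precomputed per-token offsets
            out[q0:q0 + n] = offsets_full[bid, :n]
        elif seq_lens_decoder[bid] > 0:
            # decode: rewind past rejected draft tokens, then hand out
            # consecutive offsets and advance the running counter
            offsets_decoder[bid] -= mask_rollback[bid]
            out[q0:q0 + n] = offsets_decoder[bid] + np.arange(n)
            offsets_decoder[bid] += n
            if n > 1:  # speculative step: flag the draft positions
                decode_states[bid, :n] = 0
                decode_states[bid, n:] = -1
        # seq_lens_encoder == seq_lens_decoder == 0: stopped request, no writes
    return out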
"decode_states_out"}}) + .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets)); diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 090615d6e7f..463744fc3b6 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -104,7 +104,6 @@ def __init__( self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method - self.use_speculate: bool = self.speculative_method is not None self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) @@ -369,7 +368,7 @@ def forward_mixed( getattr(layer, "cache_v_zp", None), layer.linear_shift, layer.linear_smooth, - None if self.use_speculate else forward_meta.attn_mask_offsets, + forward_meta.attn_mask_offsets, metadata.kv_signal_data_list[layer.layer_id], getattr(layer, "q_norm_weight", None), getattr(layer, "k_norm_weight", None), @@ -387,7 +386,7 @@ def forward_mixed( metadata.max_partition_size, metadata.encoder_max_partition_size, self.speculate_max_draft_token_num + 1, - self.causal or self.use_speculate, + self.causal, self.speculative_method is not None, ) return res diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index d41b2d674bd..54218ffb8a7 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -325,6 +325,7 @@ def post_process_specualate( model_output.seq_lens_this_time, model_output.is_block_step, model_output.stop_nums, + model_output.mask_rollback, ) if not skip_save_output: diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 1b8f9838481..89b9dba824e 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -70,6 +70,8 @@ def __init__(self, fd_config: FDConfig): self.max_ngram_size = self.speculative_config.max_ngram_size self.min_ngram_size = self.speculative_config.min_ngram_size + self.enable_mm = self.model_config.enable_mm + spec_logger.info(f"Speculate config: {self.speculative_config}") def run(self, *args, **kwargs) -> Any: diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index de7c645f48f..728cd70da15 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -46,6 +46,7 @@ share_external_data, speculate_get_logits, speculate_save_output_topk, + update_attn_mask_offsets, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -415,6 +416,21 @@ def _init_model_inputs(self): self.model_inputs["cu_next_token_offset"] = paddle.full( shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32" ) + self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32") + # attn_mask + if self.enable_mm: + self.model_inputs["attn_mask_offsets"] = paddle.full( + shape=[self.max_num_seqs * self.max_model_len], fill_value=-1, dtype="int32" + ) + self.model_inputs["attn_mask_offsets_full"] = paddle.full( + [self.max_num_seqs, self.max_model_len], -1, dtype="int32" + ) + self.model_inputs["attn_mask_offsets_decoder"] = paddle.full([self.max_num_seqs, 1], -1, dtype="int32") + 
self.model_inputs["decode_states"] = paddle.full( + [self.max_num_seqs, self.max_draft_token_num + 1], + -1, + dtype="int32", + ) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): @@ -453,6 +469,16 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): self.model_inputs["step_idx"][idx : idx + 1] = ( len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) + if self.enable_mm: + inputs = request.multimodal_inputs + self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = ( + paddle.to_tensor( + inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32" + ) + ) + self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = ( + inputs["attention_mask_offset"][prefill_end_index - 1] + 1 + ) # has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task @@ -585,6 +611,7 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False): cu_seqlens_k=self.model_inputs["cu_seqlens_k"], block_tables=self.model_inputs["block_tables"], caches=self.model_inputs["caches"], + attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None, ) # Initialzie attention meta data @@ -705,6 +732,21 @@ def _propose(self, step_use_cudagraph: bool = False): self.model_inputs["seq_lens_decoder"], ) + if self.enable_mm: + attn_mask_offsets = update_attn_mask_offsets( + ids_remove_padding, + getattr(self.model_inputs, "seq_lens_this_time", self.seq_lens_this_time_buffer), + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + cu_seqlens_q, + self.model_inputs["attn_mask_offsets_full"], + self.model_inputs["attn_mask_offsets_decoder"], + self.model_inputs["is_block_step"], + self.model_inputs["decode_states"], + self.model_inputs["mask_rollback"], + )[0] + self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False) + # Initialize forward meta data self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.model_inputs["batch_id_per_token"][:] = -1 diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9e0f053634c..4ac734b3977 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -848,6 +848,8 @@ def _init_share_inputs(self, max_num_seqs: int): ) self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + def _prepare_inputs(self) -> None: """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -1591,6 +1593,7 @@ class at the server level, which is too granular for ModelRunner. 
stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], ) if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill": diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 1128062814a..052ff01bafd 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -257,6 +257,11 @@ class ModelOutputData: """ prompt_lens: paddle.Tensor = None + """ + step mask rollback in some cases + """ + mask_rollback: paddle.Tensor = None + @dataclass class ModelRunnerOutput: diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 408162809f2..cd114829c4d 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -140,7 +140,9 @@ def ref_attention(self, q, k, v, mask, use_qknorm=False): .reshape([-1, self.num_q_head, self.head_dim]) ) - def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False): + def run_append_c16_attention( + self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False, mask_offset=None + ): if prefill: seq_lens_enc = [ q_len, @@ -264,7 +266,7 @@ def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, None, # cache_v_zp None, # linear_shift None, # linear_smooth - None, # mask_offset + mask_offset, # mask_offset None, # kv_signal_data self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight @@ -282,7 +284,7 @@ def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, self.max_partition_size, self.encoder_max_partition_size, decoder_step_token_num, - True, + True if mask_offset is None else False, decoder_step_token_num > 1, ) paddle.device.synchronize() @@ -353,6 +355,24 @@ def test_tree_mask(self): ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 ) + def test_mask_offset(self): + prefill_len = 8192 + dec_len_q = 5 + total_len = prefill_len + dec_len_q + mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) + mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) + self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm) + + mask_offset = paddle.tile(paddle.arange(prefill_len, prefill_len + dec_len_q), [self.bsz]).astype("int32") + dec_out = self.run_append_c16_attention( + dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm, mask_offset=mask_offset + ) + + ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm) + np.testing.assert_allclose( + ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 + ) + if __name__ == "__main__": unittest.main() From aacbf4d324367c2e9171352590aa2b2798da3cbc Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 27 Oct 2025 01:27:22 +0800 Subject: [PATCH 2/4] fix dummpy run output --- fastdeploy/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4ac734b3977..bc3f79dec82 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1251,6 +1251,7 @@ def _dummy_run( 
stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], ) post_process( From f6e852b7c44e136fcb71fa8010f7e958ae0b701c Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 27 Oct 2025 11:05:59 +0800 Subject: [PATCH 3/4] add unit test --- tests/operators/test_update_attn_mask.py | 111 +++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/operators/test_update_attn_mask.py diff --git a/tests/operators/test_update_attn_mask.py b/tests/operators/test_update_attn_mask.py new file mode 100644 index 00000000000..cd1d741ea59 --- /dev/null +++ b/tests/operators/test_update_attn_mask.py @@ -0,0 +1,111 @@ +import numpy as np +import paddle +import pytest +from mm_custom_ops import update_attn_mask_offsets + + +def run_update_attn_mask_offsets_case( + seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, is_block_step, max_model_len=8, decode_states_len=4 +): + bsz = len(seq_lens_this_time) + + # cu_seqlens_q: 累积和 + cu_seqlens_q = np.zeros(bsz, dtype="int32") + cu_seqlens_q[1:] = np.cumsum(seq_lens_this_time[:-1]) + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q, dtype="int32") + print("cu_seqlens_q", cu_seqlens_q) + # ids_remove_padding 只是用来确定 batch_seq_lens + ids_remove_padding = paddle.randint(low=0, high=10, shape=[sum(seq_lens_this_time)], dtype="int32") + + # attention_mask: (bsz, max_model_len) + attention_mask = paddle.arange(bsz * max_model_len, dtype="int32").reshape([bsz, max_model_len]) + + # 每个 batch 一个 decoder offset + attention_mask_decoder = paddle.zeros([bsz], dtype="int32") + + attention_mask_decoder[:] = paddle.to_tensor(seq_lens_decoder, dtype="int32") + + # decode_states: (bsz, decode_states_len) + decode_states = paddle.full([bsz, decode_states_len], -1, dtype="int32") + + mask_rollback = paddle.full([bsz, 1], 0, dtype="int32") + + # 调用 op + attn_mask_offsets = update_attn_mask_offsets( + ids_remove_padding, + paddle.to_tensor(seq_lens_this_time, dtype="int32"), + paddle.to_tensor(seq_lens_encoder, dtype="int32"), + paddle.to_tensor(seq_lens_decoder, dtype="int32"), + cu_seqlens_q, + attention_mask, + attention_mask_decoder, + paddle.to_tensor(is_block_step, dtype="bool"), + decode_states, + mask_rollback, + ) + if isinstance(attn_mask_offsets, list): + attn_mask_offsets = attn_mask_offsets[0] + return attn_mask_offsets.numpy(), decode_states.numpy() + + +def test_stop_case(): + attn_mask_offsets, _ = run_update_attn_mask_offsets_case( + seq_lens_this_time=[2], + seq_lens_encoder=[0], + seq_lens_decoder=[0], + is_block_step=[False], + ) + # stop 场景不应更新 + assert np.all(attn_mask_offsets == 0) or np.allclose(attn_mask_offsets, 0) + + +def test_prefill_case(): + attn_mask_offsets, _ = run_update_attn_mask_offsets_case( + seq_lens_this_time=[5], + seq_lens_encoder=[5], + seq_lens_decoder=[0], + is_block_step=[False], + ) + # 应该拷贝了 attention_mask 的一部分,不全是 0 + assert np.allclose(attn_mask_offsets, np.arange(0, 5)) + + +def test_decoder_case(): + attn_mask_offsets, decode_states_out = run_update_attn_mask_offsets_case( + seq_lens_this_time=[3], + seq_lens_encoder=[0], + seq_lens_decoder=[2], + is_block_step=[False], + ) + # decoder 场景 attn_mask_offsets 应该有非零更新 + assert np.allclose(attn_mask_offsets, np.array([2, 3, 4])) + # decode_states 前面部分应该被重置为 0 + assert np.any(decode_states_out == 0) + + +def test_non_block_step_case(): + attn_mask_offsets, _ = run_update_attn_mask_offsets_case( + seq_lens_this_time=[0, 2], + 
+        seq_lens_encoder=[0, 0],
+        seq_lens_decoder=[0, 20],
+        is_block_step=[True, False],
+    )
+    # batch 0 is inside a block step, so its query positions must not be written
+    assert np.allclose(attn_mask_offsets, np.array([20, 21]))
+
+
+def test_mixed_batch_case():
+    attn_mask_offsets, decode_states_out = run_update_attn_mask_offsets_case(
+        seq_lens_this_time=[2, 5, 1],
+        seq_lens_encoder=[0, 5, 0],
+        seq_lens_decoder=[2, 0, 2],
+        is_block_step=[False, False, False],
+    )
+    # mixed batch: every decoder/prefill entry gets its expected slice
+    assert attn_mask_offsets.shape[0] == sum([2, 5, 1])
+    assert np.allclose(attn_mask_offsets, np.array([2, 3, 8, 9, 10, 11, 12, 2]))
+    assert decode_states_out.shape[1] == 4
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

From ffe92c61cc2acbd4a82c87278d17ebe5358d4f06 Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Mon, 27 Oct 2025 14:43:27 +0800
Subject: [PATCH 4/4] fix unit test import

---
 tests/operators/test_update_attn_mask.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/operators/test_update_attn_mask.py b/tests/operators/test_update_attn_mask.py
index cd1d741ea59..f77eef28674 100644
--- a/tests/operators/test_update_attn_mask.py
+++ b/tests/operators/test_update_attn_mask.py
@@ -1,7 +1,8 @@
 import numpy as np
 import paddle
 import pytest
-from mm_custom_ops import update_attn_mask_offsets
+
+from fastdeploy.model_executor.ops.gpu import update_attn_mask_offsets
 
 
 def run_update_attn_mask_offsets_case(
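Taken together, the series leaves mask_offset with simple semantics at the attention call: judging from test_mask_offset in PATCH 1/4, query token i may attend keys 0 through mask_offset[i] inclusive. A small reference helper that builds the equivalent dense additive mask — our reading of the test, not an API anywhere in the tree:

import numpy as np

def dense_mask_from_offsets(mask_offset, total_len):
    # Additive attention mask implied by per-token offsets: row i allows
    # keys 0..mask_offset[i] (inclusive); everything else is masked out.
    q_len = len(mask_offset)
    mask = np.full((q_len, total_len), -np.inf, dtype=np.float32)
    for i, off in enumerate(mask_offset):
        mask[i, : off + 1] = 0.0
    return mask

# Mirrors test_mask_offset (prefill_len=8192, dec_len_q=5) for one request:
m = dense_mask_from_offsets(np.arange(8192, 8192 + 5), 8192 + 5)
assert np.isfinite(m).sum() == sum(range(8193, 8198))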