16 changes: 15 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -700,7 +700,8 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums);
const paddle::Tensor &stop_nums,
const paddle::Tensor &mask_rollback);

void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
@@ -934,6 +935,17 @@ void SpeculateGetTargetLogits(const paddle::Tensor &target_logits,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &accept_num);

std::vector<paddle::Tensor> UpdateAttnMaskOffsets(const paddle::Tensor& ids_remove_padding,
const paddle::Tensor& seq_lens_this_time, // only on cpu
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback);

PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
Expand Down Expand Up @@ -1328,4 +1340,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");

m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function");

m.def("update_attn_mask_offsets", &UpdateAttnMaskOffsets, "update attention mask");
}
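For orientation, here is a minimal sketch of how the newly bound `update_attn_mask_offsets` op is driven from Python, mirroring the call added in `fastdeploy/spec_decode/mtp.py` later in this diff. The import path is an assumption (the compiled `fastdeploy_ops` bindings are typically re-exported for Python callers and are not shown in this diff); the argument order follows the op's declaration above, and the dummy shapes follow the buffers allocated in `_init_model_inputs` below.

```python
import paddle
# Hypothetical import path: the compiled fastdeploy_ops bindings re-exported
# for Python callers (an assumption, not shown in this diff).
from fastdeploy.model_executor.ops.gpu import update_attn_mask_offsets

max_num_seqs, max_model_len, max_draft = 2, 8, 3
total_tokens = 5  # sum of seq_lens_this_time

# Dummy buffers shaped like the ones allocated in _init_model_inputs below.
ids_remove_padding = paddle.zeros([total_tokens], dtype="int64")
# Kept on CPU, per the "only on cpu" note in the signature.
seq_lens_this_time = paddle.to_tensor([3, 2], dtype="int32", place=paddle.CPUPlace())
seq_lens_encoder = paddle.to_tensor([[3], [0]], dtype="int32")   # lane 0: prefill
seq_lens_decoder = paddle.to_tensor([[0], [7]], dtype="int32")   # lane 1: decode
cu_seqlens_q = paddle.to_tensor([0, 3, 5], dtype="int32")
attn_mask_offsets_full = paddle.full([max_num_seqs, max_model_len], -1, dtype="int32")
attn_mask_offsets_decoder = paddle.full([max_num_seqs, 1], -1, dtype="int32")
is_block_step = paddle.zeros([max_num_seqs], dtype="bool")
decode_states = paddle.full([max_num_seqs, max_draft + 1], -1, dtype="int32")
mask_rollback = paddle.zeros([max_num_seqs, 1], dtype="int32")

# Argument order follows the op's .Inputs() list; [0] is the per-token
# offsets tensor consumed by the attention backend this step.
attn_mask_offsets = update_attn_mask_offsets(
    ids_remove_padding, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
    cu_seqlens_q, attn_mask_offsets_full, attn_mask_offsets_decoder,
    is_block_step, decode_states, mask_rollback,
)[0]
```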
21 changes: 15 additions & 6 deletions custom_ops/gpu_ops/speculate_decoding/speculate_update.cu
@@ -26,6 +26,7 @@ __global__ void speculate_update(int *seq_lens_encoder,
const int *seq_lens_this_time,
const bool *is_block_step,
const int64_t *stop_nums,
int *mask_rollback,
const int real_bsz,
const int max_bsz,
const int max_draft_tokens) {
@@ -35,9 +36,12 @@
if (!(is_block_step[bid] || bid >= real_bsz)) {
if (stop_flags[bid]) {
stop_flag_now_int = 1;
}
if (seq_lens_encoder[bid] == 0) {
mask_rollback[bid] = 0;
} else if (seq_lens_encoder[bid] == 0) { // decoder
seq_lens_decoder[bid] += accept_num_now;
mask_rollback[bid] = seq_lens_this_time[bid] - accept_num_now;
} else { // encoder
mask_rollback[bid] = 0;
}

if (seq_lens_this_time[bid] > 1 &&
@@ -97,7 +101,8 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums) {
const paddle::Tensor &stop_nums,
const paddle::Tensor &mask_rollback) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
auto max_draft_tokens = draft_tokens.shape()[1];
@@ -117,6 +122,7 @@ void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
seq_lens_this_time.data<int>(),
is_block_step.data<bool>(),
stop_nums.data<int64_t>(),
const_cast<int *>(mask_rollback.data<int>()),
real_bsz,
max_bsz,
max_draft_tokens);
@@ -138,15 +144,18 @@ PD_BUILD_STATIC_OP(speculate_update)
"stop_flags",
"seq_lens_this_time",
"is_block_step",
"stop_nums"})
"stop_nums",
"mask_rollback"})
.Outputs({"seq_lens_encoder_out",
"seq_lens_decoder_out",
"not_need_stop_out",
"draft_tokens_out",
"actual_draft_token_nums_out"})
"actual_draft_token_nums_out",
"mask_rollback_out"})
.SetInplaceMap({{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"not_need_stop", "not_need_stop_out"},
{"draft_tokens", "draft_tokens_out"},
{"actual_draft_token_nums", "actual_draft_token_nums_out"}})
{"actual_draft_token_nums", "actual_draft_token_nums_out"},
{"mask_rollback", "mask_rollback_out"}})
.SetKernelFn(PD_KERNEL(SpeculateUpdate));
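The `mask_rollback` semantics introduced in this kernel are easiest to read in scalar form: on a live decode lane, the rollback is the number of speculated positions the target model rejected this step; stopped, prefill (encoder), and block-stepped lanes roll back nothing. A minimal NumPy sketch of the same update, as a hypothetical reference helper rather than anything shipped in this PR:

```python
import numpy as np

def mask_rollback_reference(seq_lens_encoder, seq_lens_this_time, accept_num,
                            stop_flags, is_block_step):
    """Scalar restatement of the mask_rollback update in speculate_update."""
    real_bsz = len(seq_lens_this_time)
    rollback = np.zeros(real_bsz, dtype=np.int32)
    for bid in range(real_bsz):
        if is_block_step[bid]:
            continue  # lane parked by block scheduling: left untouched
        if stop_flags[bid] or seq_lens_encoder[bid] > 0:
            rollback[bid] = 0  # stopped, or still in the prefill (encoder) phase
        else:
            # Decode lane: seq_lens_this_time tokens were speculated but only
            # accept_num survived; the surplus offsets must be rolled back.
            rollback[bid] = seq_lens_this_time[bid] - accept_num[bid]
    return rollback

# Example: 4 speculated tokens, 1 accepted -> roll back 3 offsets.
print(mask_rollback_reference(
    seq_lens_encoder=[0], seq_lens_this_time=[4], accept_num=[1],
    stop_flags=[False], is_block_step=[False],
))  # -> [3]
```

`update_attn_mask_offsets` (next file) consumes this value on the following step to rewind each lane's decode offset cursor before emitting fresh offsets.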
126 changes: 126 additions & 0 deletions custom_ops/gpu_ops/speculate_decoding/update_attn_mask_offsets.cu
@@ -0,0 +1,126 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

__global__ void update_attn_mask_offsets_kernel(int* attn_mask_offsets,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int* cu_seqlens_q,
const int* attn_mask_offsets_full,
int* attn_mask_offsets_decoder,
const bool* is_block_step,
int* decode_states,
const int* mask_rollback,
const int real_bsz,
const int max_model_len,
const int decode_states_len){
int tid = threadIdx.x;

for (int bid = tid; bid < real_bsz; bid += blockDim.x) {
int seq_len_this_time = seq_lens_this_time[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid];
int query_start_id = cu_seqlens_q[bid];

const int* attn_mask_offsets_full_now = attn_mask_offsets_full + bid * max_model_len;
int* decode_states_now = decode_states + bid * decode_states_len;
if (!is_block_step[bid]) {
if (seq_len_encoder == 0 && seq_len_decoder == 0) {
// Status: stop
} else if (seq_len_encoder > 0) {
// Status: prefill -- normal or chunk_prefill
for (int i = 0; i < seq_len_this_time; i++) {
attn_mask_offsets[query_start_id + i] = attn_mask_offsets_full_now[i];
}
} else if (seq_len_decoder > 0) {
// Status: decoder -- normal or chunk_prefill
// TODO: support speculative decoding.
attn_mask_offsets_decoder[bid] -= mask_rollback[bid];

for (int i = 0; i < seq_len_this_time; i++) {
attn_mask_offsets[query_start_id + i] = attn_mask_offsets_decoder[bid] + i;
}
attn_mask_offsets_decoder[bid] += seq_len_this_time;

// Speculative decoding in text_generation
if (seq_len_this_time > 1) {
for (int i = 0; i < decode_states_len; i++) {
if (i < seq_len_this_time) {
decode_states_now[i] = 0;
} else {
decode_states_now[i] = -1;
}
}
}
}
}
}
}

std::vector<paddle::Tensor> UpdateAttnMaskOffsets(const paddle::Tensor& ids_remove_padding,
const paddle::Tensor& seq_lens_this_time, // only on cpu
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& attn_mask_offsets_full,
const paddle::Tensor& attn_mask_offsets_decoder,
const paddle::Tensor& is_block_step,
const paddle::Tensor& decode_states,
const paddle::Tensor& mask_rollback) {
int max_model_len = attn_mask_offsets_full.shape()[1];
int real_bsz = seq_lens_this_time.shape()[0];
int batch_seq_lens = ids_remove_padding.shape()[0];
int decode_states_len = decode_states.shape()[1];

auto attn_mask_offsets = paddle::empty(
{batch_seq_lens}, paddle::DataType::INT32, ids_remove_padding.place());

// launch config
int blockSize = 512;

update_attn_mask_offsets_kernel<<<1, blockSize, 0, ids_remove_padding.stream()>>>(
attn_mask_offsets.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
attn_mask_offsets_full.data<int>(),
const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
is_block_step.data<bool>(),
const_cast<int*>(decode_states.data<int>()),
mask_rollback.data<int>(),
real_bsz,
max_model_len,
decode_states_len);

return {attn_mask_offsets};
}

PD_BUILD_STATIC_OP(update_attn_mask_offsets)
.Inputs({"ids_remove_padding",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"cu_seqlens_q",
"attn_mask_offsets_full",
"attn_mask_offsets_decoder",
"is_block_step",
"decode_states",
"mask_rollback"})
.Outputs({"attn_mask_offsets", "decode_states_out"})
.SetInplaceMap({{"decode_states", "decode_states_out"}})
.SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
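Restated per batch lane in plain Python, as a hedged reference for the kernel above (list indexing stands in for the flat device pointers; this helper is not part of the PR):

```python
def update_attn_mask_offsets_ref(bid, seq_lens_this_time, seq_lens_encoder,
                                 seq_lens_decoder, cu_seqlens_q,
                                 attn_mask_offsets_full, attn_mask_offsets,
                                 attn_mask_offsets_decoder, is_block_step,
                                 decode_states, mask_rollback):
    """Mirror of update_attn_mask_offsets_kernel for one batch lane."""
    if is_block_step[bid]:
        return
    n = seq_lens_this_time[bid]
    q0 = cu_seqlens_q[bid]
    if seq_lens_encoder[bid] == 0 and seq_lens_decoder[bid] == 0:
        return  # stopped lane: nothing to emit
    if seq_lens_encoder[bid] > 0:
        # Prefill (normal or chunked): copy the precomputed offsets.
        for i in range(n):
            attn_mask_offsets[q0 + i] = attn_mask_offsets_full[bid][i]
    else:
        # Decode: first undo the offsets of draft tokens rejected last step,
        # then emit consecutive offsets and advance the cursor.
        attn_mask_offsets_decoder[bid] -= mask_rollback[bid]
        for i in range(n):
            attn_mask_offsets[q0 + i] = attn_mask_offsets_decoder[bid] + i
        attn_mask_offsets_decoder[bid] += n
        if n > 1:
            # Speculative step: mark the live draft slots, park the rest.
            for i in range(len(decode_states[bid])):
                decode_states[bid][i] = 0 if i < n else -1
```

Note the launch configuration: a single 512-thread block whose threads stride across batch lanes, which keeps the op cheap since the per-lane work is small.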
@@ -104,7 +104,6 @@ def __init__(
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
@@ -369,7 +368,7 @@ def forward_mixed(
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
None if self.use_speculate else forward_meta.attn_mask_offsets,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
@@ -387,7 +386,7 @@
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal or self.use_speculate,
self.causal,
self.speculative_method is not None,
)
return res
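The two replaced arguments above are the consumer side of the new rollback machinery: previously, speculative runs dropped the mask entirely and forced causal attention; now the mask always flows through, and stale offsets from rejected draft tokens are repaired upstream via `mask_rollback`. A minimal restatement with hypothetical stand-in names for this file's attributes:

```python
def mask_args_after_pr(causal: bool, attn_mask_offsets):
    # Before this PR:
    #   mask   = None if use_speculate else attn_mask_offsets
    #   causal = causal or use_speculate
    # After: pass the mask through unconditionally and honor the configured
    # causality; rejected-token offsets are rewound by mask_rollback instead.
    return attn_mask_offsets, causal
```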
1 change: 1 addition & 0 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -325,6 +325,7 @@ def post_process_specualate(
model_output.seq_lens_this_time,
model_output.is_block_step,
model_output.stop_nums,
model_output.mask_rollback,
)

if not skip_save_output:
2 changes: 2 additions & 0 deletions fastdeploy/spec_decode/base.py
@@ -70,6 +70,8 @@ def __init__(self, fd_config: FDConfig):
self.max_ngram_size = self.speculative_config.max_ngram_size
self.min_ngram_size = self.speculative_config.min_ngram_size

self.enable_mm = self.model_config.enable_mm

spec_logger.info(f"Speculate config: {self.speculative_config}")

def run(self, *args, **kwargs) -> Any:
42 changes: 42 additions & 0 deletions fastdeploy/spec_decode/mtp.py
@@ -46,6 +46,7 @@
share_external_data,
speculate_get_logits,
speculate_save_output_topk,
update_attn_mask_offsets,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

@@ -415,6 +416,21 @@ def _init_model_inputs(self):
self.model_inputs["cu_next_token_offset"] = paddle.full(
shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
)
self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
# attn_mask
if self.enable_mm:
self.model_inputs["attn_mask_offsets"] = paddle.full(
shape=[self.max_num_seqs * self.max_model_len], fill_value=-1, dtype="int32"
)
self.model_inputs["attn_mask_offsets_full"] = paddle.full(
[self.max_num_seqs, self.max_model_len], -1, dtype="int32"
)
self.model_inputs["attn_mask_offsets_decoder"] = paddle.full([self.max_num_seqs, 1], -1, dtype="int32")
self.model_inputs["decode_states"] = paddle.full(
[self.max_num_seqs, self.max_draft_token_num + 1],
-1,
dtype="int32",
)

def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):

@@ -453,6 +469,16 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
self.model_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
if self.enable_mm:
inputs = request.multimodal_inputs
self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
paddle.to_tensor(
inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32"
)
)
self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
inputs["attention_mask_offset"][prefill_end_index - 1] + 1
)

# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
@@ -585,6 +611,7 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
caches=self.model_inputs["caches"],
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None,
)

# Initialize attention meta data
@@ -705,6 +732,21 @@ def _propose(self, step_use_cudagraph: bool = False):
self.model_inputs["seq_lens_decoder"],
)

if self.enable_mm:
attn_mask_offsets = update_attn_mask_offsets(
ids_remove_padding,
getattr(self.model_inputs, "seq_lens_this_time", self.seq_lens_this_time_buffer),
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
cu_seqlens_q,
self.model_inputs["attn_mask_offsets_full"],
self.model_inputs["attn_mask_offsets_decoder"],
self.model_inputs["is_block_step"],
self.model_inputs["decode_states"],
self.model_inputs["mask_rollback"],
)[0]
self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)

# Initialize forward meta data
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.model_inputs["batch_id_per_token"][:] = -1
4 changes: 4 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -848,6 +848,8 @@ def _init_share_inputs(self, max_num_seqs: int):
)
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
@@ -1249,6 +1251,7 @@ def _dummy_run(
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
)

post_process(
@@ -1591,6 +1594,7 @@ class at the server level, which is too granular for ModelRunner.
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
)

if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
5 changes: 5 additions & 0 deletions fastdeploy/worker/output.py
@@ -257,6 +257,11 @@ class ModelOutputData:
"""
prompt_lens: paddle.Tensor = None

"""
step mask rollback in some cases
"""
mask_rollback: paddle.Tensor = None


@dataclass
class ModelRunnerOutput: