Skip to content

Commit 95ed77f

Browse files
committed
delete AscendRejectionSampler
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
1 parent 081c1eb commit 95ed77f

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import vllm.v1.sample.rejection_sampler as rs
2+
from vllm_ascend.sample.rejection_sampler import expand_batch_to_tokens, rejection_sample
3+
4+
rs.expand_batch_to_tokens = expand_batch_to_tokens
5+
rs.rejection_sample = rejection_sample

vllm_ascend/sample/rejection_sampler.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
MAX_SPEC_LEN = 32
1919

2020

21-
class AscendRejectionSampler(RejectionSampler, nn.Module):
22-
pass
23-
24-
2521
def rejection_sample(
2622
# [num_tokens]
2723
draft_token_ids: torch.Tensor,
@@ -695,7 +691,3 @@ def sample_recovered_tokens_kernel(
695691
tl.store(
696692
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
697693
orig_prob)
698-
699-
700-
rs.expand_batch_to_tokens = expand_batch_to_tokens
701-
rs.rejection_sample = rejection_sample

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
make_empty_encoder_model_runner_output)
9898
from vllm.v1.pool.metadata import PoolingMetadata
9999
from vllm.v1.sample.metadata import SamplingMetadata
100+
from vllm.v1.sample.rejection_sampler import RejectionSampler
100101
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
101102
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
102103
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -137,7 +138,6 @@
137138
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
138139
from vllm_ascend.platform import NPUPlatform
139140
from vllm_ascend.sample.logits_processor import build_logitsprocs
140-
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
141141
from vllm_ascend.spec_decode import get_spec_decode_method
142142
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
143143
from vllm_ascend.spec_decode.interface import SpecDcodeType
@@ -634,7 +634,7 @@ def _set_up_drafter(self):
634634
diagonal=1).to(self.device)
635635
if get_pp_group().is_last_rank:
636636
self.drafter = self._get_drafter()
637-
self.rejection_sampler = AscendRejectionSampler(self.sampler)
637+
self.rejection_sampler = RejectionSampler(self.sampler)
638638
self.actual_seq_lengths_q = list(
639639
range(self.decode_token_per_req, self.max_num_tokens + 1,
640640
self.decode_token_per_req))

0 commit comments

Comments
 (0)