forked from AkaliKong/MiniOneRec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLogitProcessor.py
More file actions
63 lines (50 loc) · 2.39 KB
/
LogitProcessor.py
File metadata and controls
63 lines (50 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers.generation import LogitsProcessor
from transformers import AutoTokenizer
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import math
import numpy as np
import torch
from transformers.utils import add_start_docstrings
# Shared docstring injected into __call__ via @add_start_docstrings below.
# NOTE: this is a runtime string consumed by transformers' doc tooling —
# its content must match the LogitsProcessor.__call__ contract.
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Constrains beam-search decoding to a caller-supplied allowed-token set.

    On each decoding step, the trailing tokens of every beam are used as a
    lookup key into ``prefix_allowed_tokens_fn``; every vocabulary token not
    returned by that function receives a large negative additive penalty, so
    beam search can only continue along allowed prefixes.

    NOTE(review): ``self.count`` accumulates across calls, so one instance is
    only valid for a single ``generate(...)`` invocation — reuse across calls
    would use stale suffix lengths. Preserved as-is; confirm with callers.
    """

    def __init__(
        self,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        base_model: Optional[str] = None,
    ):
        """
        Args:
            prefix_allowed_tokens_fn: maps ``(batch_id, prefix_token_list)``
                to the list of token ids allowed next. An empty return list
                leaves that beam fully penalized (no token unmasked).
            num_beams: number of beams per batch element; used to fold the
                flat ``(batch*beams, ...)`` tensors back into per-batch view.
            base_model: model name. GPT-2 style models use the last 4 tokens
                as the first-step prefix key; all others use the last 3.
        """
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams
        # Number of decoding steps processed so far; after the first step it
        # equals the length of the generated suffix used as the lookup key.
        self.count = 0
        self.base_model = base_model
        # Bug fix: the original unconditionally called base_model.lower(),
        # raising AttributeError when the documented default (None) was used.
        if base_model is not None and "gpt2" in base_model.lower():
            self.prefix_index = 4
        else:
            self.prefix_index = 3

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Normalize so the additive penalty acts on log-probabilities.
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        # Large negative (not -inf) so fully-masked beams still yield finite
        # scores instead of NaNs after later softmax/renormalization.
        mask = torch.full_like(scores, -1000000)
        # View the flat (batch*beams, seq) ids as (batch, beams, seq).
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            for beam_id, sent in enumerate(beam_sent):
                if self.count == 0:
                    # First step: key on the fixed-size prompt tail.
                    hash_key = sent[-self.prefix_index:]
                else:
                    # Later steps: key on everything generated so far.
                    hash_key = sent[-self.count:]
                hash_key = hash_key.tolist()
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, hash_key)
                if not prefix_allowed_tokens:
                    # No allowed continuation: leave the beam fully penalized.
                    continue
                mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
        self.count += 1
        return scores + mask