forked from timoschick/pet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessor.py
More file actions
90 lines (66 loc) · 3.55 KB
/
preprocessor.py
File metadata and controls
90 lines (66 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from abc import ABC, abstractmethod
from utils import InputFeatures, InputExample
from pvp import AgnewsPVP, MnliPVP, YelpPolarityPVP, YelpFullPVP, \
YahooPVP, PVP, XStancePVP
# Registry mapping each task name to the pattern-verbalizer-pair (PVP) class
# that converts its examples into cloze-style model inputs. All xstance
# variants (full / de / fr) deliberately share the same PVP implementation.
PVPS = {
    'agnews': AgnewsPVP,
    'mnli': MnliPVP,
    'yelp-polarity': YelpPolarityPVP,
    'yelp-full': YelpFullPVP,
    'yahoo': YahooPVP,
    'xstance': XStancePVP,
    'xstance-de': XStancePVP,
    'xstance-fr': XStancePVP,
}
class Preprocessor(ABC):
    """Abstract base for converting InputExamples into model-ready InputFeatures."""

    def __init__(self, wrapper, task_name, pattern_id: int = 0, verbalizer_file: str = None):
        """Resolve the task's PVP and build the label-string -> index mapping.

        :param wrapper: the model wrapper this preprocessor belongs to
        :param task_name: key into the ``PVPS`` registry
        :param pattern_id: which pattern of the PVP to use
        :param verbalizer_file: optional file overriding the PVP's verbalizer
        """
        self.wrapper = wrapper
        # Instantiate the pattern-verbalizer pair registered for this task.
        self.pvp = PVPS[task_name](self.wrapper, pattern_id, verbalizer_file)  # type: PVP
        labels = self.wrapper.config.label_list
        self.label_map = dict(zip(labels, range(len(labels))))

    @abstractmethod
    def get_input_features(self, example: InputExample, labelled: bool, **kwargs) -> InputFeatures:
        """Convert a single ``example`` into ``InputFeatures``."""
        pass
class MLMPreprocessor(Preprocessor):
    """Preprocessor for models trained with a masked-language-modelling head."""

    def get_input_features(self, example: InputExample, labelled: bool, **kwargs) -> InputFeatures:
        """Encode ``example`` through the PVP, right-pad to ``max_seq_length``
        and, if ``labelled``, mark the mask positions as MLM labels."""
        max_len = self.wrapper.config.max_seq_length
        input_ids, token_type_ids = self.pvp.encode(example)

        # Right-pad every sequence up to the fixed maximum length.
        pad_len = max_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        input_ids = input_ids + [self.wrapper.tokenizer.pad_token_id] * pad_len
        token_type_ids = token_type_ids + [0] * pad_len

        for seq in (input_ids, attention_mask, token_type_ids):
            assert len(seq) == max_len

        label = self.label_map[example.label]
        logits = example.logits if example.logits else [-1]

        # Unlabelled examples carry no mask positions: every slot is -1.
        mlm_labels = self.pvp.get_mask_positions(input_ids) if labelled else [-1] * max_len

        return InputFeatures(input_ids=input_ids, attention_mask=attention_mask,
                             token_type_ids=token_type_ids, label=label,
                             mlm_labels=mlm_labels, logits=logits)
class SequenceClassifierPreprocessor(Preprocessor):
    """Preprocessor for a regular sequence-classification head (no pattern/verbalizer)."""

    def get_input_features(self, example: InputExample, labelled: bool = True, **kwargs) -> InputFeatures:
        """Tokenize ``example`` directly (text_a [+ text_b]), pad to
        ``max_seq_length`` and return ``InputFeatures``.

        :param example: the example to featurize; ``example.label`` must be a
            key of ``self.label_map``
        :param labelled: accepted for signature compatibility with the abstract
            ``Preprocessor.get_input_features`` (the base class declares it as a
            positional parameter, so callers may pass it positionally). A
            sequence classifier has no mask positions, so it is unused and
            ``mlm_labels`` is always all -1.
        """
        inputs = self.wrapper.tokenizer.encode_plus(
            example.text_a if example.text_a else None,
            example.text_b if example.text_b else None,
            add_special_tokens=True,
            max_length=self.wrapper.config.max_seq_length,
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids")

        attention_mask = [1] * len(input_ids)
        padding_length = self.wrapper.config.max_seq_length - len(input_ids)

        input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        if not token_type_ids:
            # Tokenizers without segment ids return None/[] here; use all zeros.
            token_type_ids = [0] * self.wrapper.config.max_seq_length
        else:
            token_type_ids = token_type_ids + ([0] * padding_length)
        # No MLM objective for a sequence classifier: every position is -1.
        mlm_labels = [-1] * len(input_ids)

        assert len(input_ids) == self.wrapper.config.max_seq_length
        assert len(attention_mask) == self.wrapper.config.max_seq_length
        assert len(token_type_ids) == self.wrapper.config.max_seq_length

        label = self.label_map[example.label]
        logits = example.logits if example.logits else [-1]

        return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                             label=label, mlm_labels=mlm_labels, logits=logits)