-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassify.py
More file actions
162 lines (133 loc) · 5.04 KB
/
classify.py
File metadata and controls
162 lines (133 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import sys
import argparse

import torch
import torch.nn as nn
# NOTE(review): this unconditional import is redundant — `data` is rebound by
# the version-guarded import below, and on torchtext >= 0.12 (where the legacy
# Field API was removed) keeping it here serves no purpose. Consider removing.
from torchtext import data
import torchtext

from my_ntc.models.rnn import RNNClassifier
from my_ntc.models.cnn import CNNClassifier

# torchtext moved the old Field/Dataset API under torchtext.legacy in 0.9;
# choose the import path that matches the installed version so `data.Field`
# resolves either way.
version = list(map(int, torchtext.__version__.split('.')))
if version[0] <= 0 and version[1] < 9:
    from torchtext import data
else:
    from torchtext.legacy import data
def define_argparser():
    '''
    Build and parse command-line arguments for running inference
    with a pre-trained model.
    '''
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_fn', required=True)            # path to the saved checkpoint
    parser.add_argument('--gpu_id', type=int, default=-1)       # -1 selects CPU
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--top_k', type=int, default=1)         # number of labels to print per line
    parser.add_argument('--max_length', type=int, default=256)  # truncate input to this many tokens
    # store_true flags default to False and flip to True when the flag is
    # present on the command line (store_false would be the inverse).
    parser.add_argument('--drop_rnn', action='store_true')      # exclude the RNN from the ensemble
    parser.add_argument('--drop_cnn', action='store_true')      # exclude the CNN from the ensemble

    return parser.parse_args()
def read_text(max_length=256):
    '''
    Read input for inference from standard input.

    Each non-blank line is stripped, split on single spaces, and truncated
    to at most `max_length` tokens; blank lines are skipped entirely.
    Returns a list of token lists, one per kept line.
    '''
    return [
        stripped.split(' ')[:max_length]
        for stripped in (line.strip() for line in sys.stdin)
        if stripped != ''
    ]
def define_field():
'''
To avoid use DataLoader class, just declare dummy fields.
With those fields, we can retore mapping table between words and indice.
'''
return (
data.Field(
use_vocab=True,
batch_first=True,
include_lengths=False,
),
data.Field(
sequential=False,
use_vocab=True,
unk_token=None,
)
)
def main(config):
    '''
    Classify lines read from stdin with the saved RNN/CNN ensemble and
    write one output line per input line to stdout, formatted as
    "<top-k labels joined by spaces>\t<original tokens joined by spaces>".
    '''
    # Load everything the training run saved in one checkpoint: best weights
    # for both models, the training configuration, and the vocab/label tables.
    saved_data = torch.load(
        config.model_fn,
        map_location='cpu' if config.gpu_id < 0 else 'cuda:%d' % config.gpu_id
    )
    # print(saved_data) # check saved model

    train_config = saved_data['config']
    rnn_best = saved_data['rnn']
    cnn_best = saved_data['cnn']
    vocab = saved_data['vocab']
    classes = saved_data['classes']

    vocab_size = len(vocab)
    n_classes = len(classes)

    # Attach the restored vocabularies to dummy fields so that
    # text_field.numericalize() can map tokens to indices below.
    text_field, label_field = define_field()
    text_field.vocab = vocab
    label_field.vocab = classes

    lines = read_text(max_length=config.max_length)

    with torch.no_grad():
        # Build the ensemble from every saved model that was not
        # explicitly dropped via --drop_rnn / --drop_cnn.
        ensemble = []
        if rnn_best is not None and not config.drop_rnn:
            # Declare model and load pre-trained weights.
            model = RNNClassifier(
                input_size=vocab_size,
                word_vec_size=train_config.word_vec_size,
                hidden_size=train_config.hidden_size,
                n_classes=n_classes,
                n_layers=train_config.n_layers,
                dropout_p=train_config.dropout,
            )
            model.load_state_dict(rnn_best)
            ensemble += [model]
        if cnn_best is not None and not config.drop_cnn:
            # Declare model and load pre-trained weights.
            model = CNNClassifier(
                input_size=vocab_size,
                word_vec_size=train_config.word_vec_size,
                n_classes=n_classes,
                use_batch_norm=train_config.use_batch_norm,
                dropout_p=train_config.dropout,
                window_sizes=train_config.window_sizes,
                n_filters=train_config.n_filters,
            )
            model.load_state_dict(cnn_best)
            ensemble += [model]

        y_hats = []
        # Get prediction with iteration on ensemble.
        for model in ensemble:
            if config.gpu_id >= 0:
                model.cuda(config.gpu_id)
            # Don't forget turn-on evaluation mode.
            model.eval()

            y_hat = []
            # Mini-batch over the input lines to bound memory usage.
            for idx in range(0, len(lines), config.batch_size):
                # Converts string to list of index.
                x = text_field.numericalize(
                    text_field.pad(lines[idx:idx + config.batch_size]),
                    device='cuda:%d' % config.gpu_id if config.gpu_id >= 0 else 'cpu',
                )

                # Predictions are moved to CPU immediately so results can be
                # concatenated regardless of which device ran the model.
                y_hat += [model(x).cpu()]
            # Concatenate the mini-batch wise result
            y_hat = torch.cat(y_hat, dim=0)
            # |y_hat| = (len(lines), n_classes)

            y_hats += [y_hat]
            # Move this model back to CPU before loading the next one.
            model.cpu()

        # Merge to one tensor for ensemble result and make probability from log-prob.
        # NOTE(review): .exp() assumes each model emits log-probabilities
        # (e.g. log_softmax output) — confirm against the model definitions.
        y_hats = torch.stack(y_hats).exp()
        # |y_hats| = (len(ensemble), len(lines), n_classes)
        y_hats = y_hats.sum(dim=0) / len(ensemble)  # Get average
        # |y_hats| = (len(lines), n_classes)

        probs, indice = y_hats.topk(config.top_k)

        # Emit "<labels>\t<input text>" per line, labels mapped back through
        # the label vocabulary's index-to-string table.
        for i in range(len(lines)):
            sys.stdout.write('%s\t%s\n' % (
                ' '.join([classes.itos[indice[i][j]] for j in range(config.top_k)]),
                ' '.join(lines[i])
            ))
# Script entry point: parse the CLI arguments and run inference.
if __name__ == '__main__':
    main(define_argparser())