-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
82 lines (61 loc) · 2.97 KB
/
model.py
File metadata and controls
82 lines (61 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import random
import torch
import torch.nn as nn
class Encoder(nn.Module):
    """LSTM encoder for a sequence-to-sequence model.

    Embeds source token ids, applies dropout to the embeddings, runs them
    through a (possibly stacked) LSTM and returns the final hidden and cell
    states, which the decoder uses as its initial state.

    Args:
        input_dim: Size of the source vocabulary (rows of the embedding table).
        emb_dim: Embedding dimension.
        hid_dim: LSTM hidden-state size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``n_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # nn.LSTM only applies dropout *between* stacked layers. With a single
        # layer a non-zero value does nothing except emit a UserWarning, so
        # pass 0.0 in that case (numerically identical behaviour).
        self.rnn = nn.LSTM(
            emb_dim, hid_dim, n_layers, dropout=dropout if n_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src`` and return the final LSTM state.

        Args:
            src: Token ids of shape ``[src_len, batch_size]`` (sequence-first,
                matching nn.LSTM's default ``batch_first=False``).

        Returns:
            ``(hidden, cell)``, each of shape ``[n_layers, batch_size, hid_dim]``.
            These are needed by the decoder as its start state.
        """
        # embedded: [src_len, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(src))
        # Per-step outputs are discarded; the decoder only consumes the final state.
        _, (hidden, cell) = self.rnn(embedded)
        return hidden, cell
class Decoder(nn.Module):
    """Single-step LSTM decoder for a sequence-to-sequence model.

    Each call consumes one token per batch element together with the previous
    LSTM state and returns vocabulary logits for the next token plus the
    updated state.

    Args:
        output_dim: Size of the target vocabulary (embedding rows and number
            of output logits).
        emb_dim: Embedding dimension.
        hid_dim: LSTM hidden-state size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``n_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        # Read by Seq2Seq to size its output tensor.
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # nn.LSTM only applies dropout *between* stacked layers. With a single
        # layer a non-zero value does nothing except emit a UserWarning, so
        # pass 0.0 in that case (numerically identical behaviour).
        self.rnn = nn.LSTM(
            emb_dim, hid_dim, n_layers, dropout=dropout if n_layers > 1 else 0.0
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_tokens, hidden, cell):
        """Run one decoding step.

        Args:
            input_tokens: Token ids of shape ``[batch_size]`` (one token per
                example in the batch).
            hidden: Previous hidden state, ``[n_layers, batch_size, hid_dim]``.
            cell: Previous cell state, ``[n_layers, batch_size, hid_dim]``.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` has shape
            ``[batch_size, output_dim]`` and the states keep their input shapes.
        """
        # Add a seq_len = 1 dimension: [batch_size] -> [1, batch_size]
        input_tokens = input_tokens.unsqueeze(0)
        # embedded: [1, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(input_tokens))
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # Drop the seq_len dimension before projecting to vocabulary logits.
        prediction = self.fc_out(output.squeeze(0))
        # prediction: [batch_size, output_dim]
        return prediction, hidden, cell
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that generates the target sequence step by step.

    The encoder's final ``(hidden, cell)`` state is handed to the decoder
    unchanged, which is why both LSTMs must agree on hidden size and depth.

    Args:
        encoder: Module whose ``forward(src)`` returns ``(hidden, cell)`` and
            which exposes its LSTM as ``.rnn``.
        decoder: Module whose ``forward(tokens, hidden, cell)`` returns
            ``(prediction, hidden, cell)``, exposes its LSTM as ``.rnn`` and
            the target vocabulary size as ``.output_dim``.
        device: Device on which the output tensor is allocated.

    Raises:
        AssertionError: If the encoder and decoder LSTMs differ in hidden
            size or number of layers.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert encoder.rnn.hidden_size == decoder.rnn.hidden_size, \
            "Hidden dimensions of encoder and decoder must match!"
        assert encoder.rnn.num_layers == decoder.rnn.num_layers, \
            "Number of layers in encoder and decoder must match!"

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg_len - 1`` steps, optionally feeding back ground truth.

        Args:
            src: Source token ids, ``[src_len, batch_size]``.
            trg: Target token ids, ``[trg_len, batch_size]``. ``trg[0]`` is used
                only as the first decoder input (presumably a start-of-sequence
                token — confirm against the data pipeline).
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token ``trg[t]`` as the next input instead of the
                model's own argmax prediction. One ``random.random()`` draw is
                made per step.

        Returns:
            Logits of shape ``[trg_len, batch_size, trg_vocab_size]``. Row 0 is
            never written and stays all zeros; rows ``1..trg_len-1`` hold the
            decoder predictions.
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # Allocate directly on the target device rather than building a CPU
        # tensor and copying it over with .to().
        outputs = torch.zeros(
            trg_len, batch_size, trg_vocab_size, device=self.device
        )
        hidden, cell = self.encoder(src)
        input_tokens = trg[0, :]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input_tokens, hidden, cell)
            outputs[t] = output
            # Teacher forcing: decide per step whether to feed ground truth.
            teacher_force = random.random() < teacher_forcing_ratio
            # Greedy prediction is only computed when it is actually fed back.
            input_tokens = trg[t] if teacher_force else output.argmax(1)
        return outputs