-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
51 lines (36 loc) · 1.93 KB
/
model.py
File metadata and controls
51 lines (36 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torch import nn
from torch.nn.functional import relu
# Token-to-index vocabulary saved at training time; subscripted with '<PAD>'
# below, so it must be a dict-like mapping containing that token.
# NOTE(review): torch.load unpickles and can execute arbitrary code if the
# file is untrusted — confirm 'models/vocab.pth' comes from a trusted source.
vocab = torch.load('models/vocab.pth')
class Attention(nn.Module):
    """Attention pooling over the time axis of a BiLSTM output.

    Each timestep is scored by a single linear unit; the scores are
    softmax-normalized along the sequence dimension and used to take a
    weighted sum of the timesteps, collapsing the sequence to one vector.
    """

    def __init__(self, lstm_hidden_dim):
        super().__init__()
        # Input width is 2 * hidden because the upstream LSTM is bidirectional.
        self.attention = nn.Linear(lstm_hidden_dim * 2, 1)

    def forward(self, lstm_out):
        # lstm_out: assumed (batch, seq_len, 2 * lstm_hidden_dim) — TODO confirm with caller.
        scores = self.attention(lstm_out)        # (batch, seq_len, 1)
        weights = torch.softmax(scores, dim=1)   # normalize across timesteps
        return (lstm_out * weights).sum(dim=1)   # (batch, 2 * lstm_hidden_dim)
class SentimentCNNBiLSTM(nn.Module):
    """Sentiment classifier: embedding -> 3x Conv1d -> max-pool -> BiLSTM -> attention -> linear.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width.
        conv_filters: channel count of every convolution layer.
        lstm_hidden_dim: hidden size per LSTM direction.
        output_dim: number of output classes/logits.
        dropout: dropout probability applied to the attended vector.
        padding_idx: embedding padding index. Defaults to ``vocab['<PAD>']``
            (the original hard-coded behavior); pass it explicitly to remove
            the dependency on the module-level ``vocab``.
    """

    def __init__(self, vocab_size, embedding_dim, conv_filters, lstm_hidden_dim,
                 output_dim, dropout, padding_idx=None):
        super(SentimentCNNBiLSTM, self).__init__()
        # Backward-compatible generalization: only touch the module-level
        # vocab when the caller does not supply a padding index.
        if padding_idx is None:
            padding_idx = vocab['<PAD>']
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Paddings (2, 2, 2) with kernels (3, 5, 7) change the sequence length
        # by +2, 0, -2 respectively, so the conv stack as a whole preserves the
        # input length before the stride-2 max-pool halves it.
        self.conv1d_1 = nn.Conv1d(in_channels=embedding_dim, out_channels=conv_filters, kernel_size=3, padding=2)
        self.conv1d_2 = nn.Conv1d(in_channels=conv_filters, out_channels=conv_filters, kernel_size=5, padding=2)
        self.conv1d_3 = nn.Conv1d(in_channels=conv_filters, out_channels=conv_filters, kernel_size=7, padding=2)
        self.maxpool = nn.MaxPool1d(kernel_size=2)
        self.bilstm = nn.LSTM(conv_filters, lstm_hidden_dim, num_layers=2, bidirectional=True, batch_first=True)
        self.attention = Attention(lstm_hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # BiLSTM output width is 2 * hidden (forward + backward directions).
        self.fc = nn.Linear(lstm_hidden_dim * 2, output_dim)

    def forward(self, input_ids):
        """Return logits of shape (batch, output_dim) for integer token ids."""
        embedded = self.embedding(input_ids)          # (batch, seq, emb)
        embedded = embedded.permute(0, 2, 1)          # Conv1d wants (batch, channels, seq)
        conv_out = relu(self.conv1d_1(embedded))
        conv_out = relu(self.conv1d_2(conv_out))
        conv_out = relu(self.conv1d_3(conv_out))
        # Back to (batch, seq, channels) for the batch_first LSTM.
        pooled_out = self.maxpool(conv_out).permute(0, 2, 1)
        lstm_out, _ = self.bilstm(pooled_out)
        attended = self.attention(lstm_out)           # pool sequence to one vector
        attended = self.dropout(attended)
        output = self.fc(attended)
        return output