-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path predictions.py
More file actions
35 lines (25 loc) · 1.15 KB
/
predictions.py
File metadata and controls
35 lines (25 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from preprocessing import clean_text, tokenize, encode_tokens, pad_sequence
from model import SentimentCNNBiLSTM
# Run on the GPU when one is available, otherwise fall back to the CPU.
# torch.load accepts this device string directly as map_location, so every
# artifact is remapped onto the chosen device as it is loaded.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): these .pth files hold whole pickled Python objects, not just
# tensors: the model is loaded as a callable module object (it is invoked
# directly further down) rather than a state_dict, and the label encoder is
# presumably a scikit-learn LabelEncoder (it exposes `classes_`) — confirm.
# Since torch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle such objects, so weights_only=False is passed explicitly.
# SECURITY: weights_only=False runs full pickle deserialization, which can
# execute arbitrary code — only load these files from a trusted source.
# Unpickling the model needs its class to be importable, which is presumably
# why SentimentCNNBiLSTM is imported at the top of the file.
vocab = torch.load('models/vocab.pth', map_location=device, weights_only=False)
label_encoder = torch.load('models/label_encoder.pth', map_location=device, weights_only=False)
model = torch.load('models/sentiment_model.pth', map_location=device, weights_only=False)
# Inference only: switch the model to evaluation mode once, before the loop,
# instead of re-calling eval() for every sentence.
model.eval()

# Interactive loop: classify each entered sentence until the user types
# 'exit' or closes/interrupts input (Ctrl-D / Ctrl-C).
while True:
    try:
        prompt = input("Enter a sentence: ")
    except (EOFError, KeyboardInterrupt):
        # End of input or interrupt: leave the loop cleanly instead of
        # exiting with a traceback.
        print()
        break
    if prompt.strip() == 'exit':
        break
    if not prompt.strip():
        # Nothing to classify; an empty prompt would only yield padding.
        continue

    # Preprocess with the project's helpers: clean -> tokenize -> map tokens
    # to ids via `vocab` -> pad to a fixed length of 128.
    # NOTE(review): 128 is assumed to match the sequence length used in
    # training — confirm against the training script.
    token_ids = encode_tokens(tokenize(clean_text(prompt)), vocab)
    token_ids = pad_sequence(token_ids, 128)
    # Add a leading batch dimension of size 1 and move to the model's device.
    tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)
    # Index of the highest-scoring class along dim 1. Softmax is monotonic,
    # so it is not needed just to pick the argmax.
    predicted_sentiment = output.argmax(dim=1).item()
    # NOTE(review): this prints the 1-based class index, as before. If the
    # encoder's classes are not simply 1..K, the decoded label
    # label_encoder.classes_[predicted_sentiment] may be what was intended —
    # confirm against how labels were encoded during training.
    print(f"Predicted Sentiment: {predicted_sentiment+1}")