-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
102 lines (81 loc) · 3.64 KB
/
train.py
File metadata and controls
102 lines (81 loc) · 3.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torch
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from torch.utils.data import DataLoader
from typing import Tuple, List
from torch import nn, optim, Tensor
from torch.cuda.amp import autocast
def train_epoch(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, optimizer: optim.Optimizer, device: str, scheduler, scaler=None) -> Tuple[float, float]:
    """Run one training epoch and return (accuracy, average_loss).

    Args:
        model: network called as ``model(input_ids, mask=None)`` returning class logits.
        dataloader: yields dict batches with ``'input_ids'`` and ``'label'`` tensors.
        loss_fn: criterion applied to ``(outputs, labels)``.
        optimizer: stepped once per batch after gradient clipping.
        device: device the batch tensors are moved to.
        scheduler: LR scheduler, stepped once at the end of the epoch.
        scaler: optional ``torch.cuda.amp.GradScaler``; when provided, the
            forward pass runs under autocast and the backward pass uses
            scaled gradients.

    Returns:
        Tuple of (epoch accuracy over the whole dataset, mean per-batch loss).
    """
    model.train()
    losses: List[float] = []
    correct_predictions: int = 0
    for batch in tqdm(dataloader, desc='Training', leave=True, position=0):
        input_ids = batch['input_ids'].to(device)
        labels = batch['label'].to(device)
        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            # Mixed-precision forward pass.
            with autocast():
                outputs = model(input_ids, mask=None)
                loss = loss_fn(outputs, labels)
        else:
            outputs = model(input_ids, mask=None)
            loss = loss_fn(outputs, labels)
        losses.append(loss.item())
        if scaler is None:
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            # BUGFIX: gradients must be unscaled before clipping; the original
            # clipped the *scaled* gradients, so the effective clip threshold
            # was max_norm * scale instead of max_norm.
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        _, preds = torch.max(outputs, dim=1)
        correct_predictions += torch.sum(preds == labels)
    # NOTE(review): stepped once per epoch here (matching the original's
    # placement); if this is a per-step warmup schedule it belongs inside
    # the batch loop — confirm against the scheduler used by the caller.
    scheduler.step()
    accuracy = correct_predictions.double() / len(dataloader.dataset)
    average_loss = np.mean(losses)
    return accuracy, average_loss
def eval_model(model: nn.Module, dataloader: DataLoader, loss_fn: nn.Module, device: torch.device) -> Tuple[float, float]:
    """Evaluate the model on a dataloader and return (accuracy, average_loss).

    Runs under ``torch.no_grad``; on CUDA the forward pass and loss are
    computed under autocast (mixed precision).

    Args:
        model: network called as ``model(input_ids, mask=None)`` returning class logits.
        dataloader: yields dict batches with ``'input_ids'`` and ``'label'`` tensors.
        loss_fn: criterion applied to ``(outputs, labels)``.
        device: target device; compared against ``'cuda'`` to decide whether
            to enable autocast (torch.device('cuda') == 'cuda' is True).

    Returns:
        Tuple of (accuracy over the whole dataset, mean per-batch loss).
    """
    model.eval()
    losses: List[float] = []
    correct_predictions: int = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation', leave=True, position=0):
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            if device == 'cuda':
                with autocast():  # Enable mixed precision during evaluation
                    outputs = model(input_ids, mask=None)
                    loss = loss_fn(outputs, labels)
            else:
                outputs = model(input_ids, mask=None)
                loss = loss_fn(outputs, labels)
            _, preds = torch.max(outputs, dim=1)
            # BUGFIX: removed a redundant second `loss_fn(outputs, labels)`
            # call that recomputed the already-available loss (and did so
            # outside autocast on the GPU path).
            correct_predictions += torch.sum(preds == labels)
            losses.append(loss.item())
    accuracy = correct_predictions.double() / len(dataloader.dataset)
    average_loss = np.mean(losses)
    return accuracy, average_loss
def get_predictions(model: nn.Module, dataloader: DataLoader, device: torch.device) -> Tuple[Tensor, Tensor]:
    """Run inference over a dataloader and collect predictions and labels.

    Args:
        model: network called as ``model(input_ids, mask=None)`` returning class logits.
        dataloader: yields dict batches with ``'input_ids'`` and ``'label'`` tensors.
        device: target device; autocast is enabled when it equals ``'cuda'``.

    Returns:
        Tuple ``(predictions, real_values)`` — 1-D tensors of predicted class
        indices and ground-truth labels, stacked across all batches and moved
        to ``device``.
    """
    model.eval()
    predictions: List[Tensor] = []
    real_values: List[Tensor] = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Testing', leave=False):
            input_ids = batch['input_ids'].to(device)
            labels = batch['label'].to(device)
            if device == 'cuda':
                with autocast():
                    outputs = model(input_ids, mask=None)
            else:
                outputs = model(input_ids, mask=None)
            # BUGFIX: the original did `_, preds = torch.softmax(outputs, dim=1)`,
            # which tuple-unpacks a plain tensor along dim 0 — that raises
            # ValueError for any batch size other than 2 and yields the second
            # sample's probabilities (not class indices) when it doesn't.
            # torch.max returns (values, indices), matching train_epoch/eval_model;
            # argmax of logits equals argmax of softmax, so semantics are the
            # evident intent.
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds)
            real_values.extend(labels)
    predictions = torch.stack(predictions).to(device)
    real_values = torch.stack(real_values).to(device)
    return predictions, real_values