-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
98 lines (73 loc) · 3.16 KB
/
train.py
File metadata and controls
98 lines (73 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from command_dataset import CommandDataset
from model import SpeechRecognitionModel
from data_processing import preprocess_model
# --- Training hyperparameters ---
EPOCHS = 25
MAX_LEARNING_RATE = 5e-4  # peak learning rate reached by the OneCycleLR schedule
BATCH_SIZE = 2
# --- Dataset / model file locations ---
ROOT_DIR = './data/'
CSV_FILE = f'{ROOT_DIR}command_labels.csv'  # label CSV consumed by CommandDataset
MODEL_PATH = './command_model.pth'  # pre-existing weights loaded at startup if present
CHECKPOINT_PATH = './command_model_chkpnt.pth'  # best-so-far weights written during training
# Running minimum batch loss; mutated as a global inside train() to decide
# when to write a new checkpoint. Initial value 10 acts as "worse than any
# loss we expect to see early on" — TODO confirm against typical CTC losses.
MIN_LOSS = 10
def printout(_str):
    """Print *_str* to stdout and flush immediately.

    Flushing ensures progress messages appear in real time even when stdout
    is block-buffered (e.g. redirected to a log file). Uses print's built-in
    ``flush=True`` instead of a separate ``sys.stdout.flush()`` call.
    """
    print(_str, flush=True)
def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, log_interval=1):
    """Run one training epoch over ``train_loader``.

    For each batch: forward pass, CTC loss, backward pass, optimizer and
    LR-scheduler step. Logs progress every ``log_interval`` batches and
    writes a checkpoint to ``CHECKPOINT_PATH`` whenever the batch loss
    improves on the global ``MIN_LOSS``.

    Args:
        model: network producing (batch, time, n_class) logits.
        device: torch device the model/data are moved to.
        train_loader: DataLoader yielding
            (spectrograms, labels, input_lengths, label_lengths) tuples.
        criterion: CTC loss (expects (time, batch, n_class) log-probs).
        optimizer: optimizer over model parameters.
        scheduler: per-batch LR scheduler (OneCycleLR steps every batch).
        epoch: current epoch number, used only for logging.
        log_interval: log every N batches (default 1 preserves the original
            behavior of logging every batch).

    Side effects: mutates global MIN_LOSS, may write CHECKPOINT_PATH.
    """
    global MIN_LOSS
    model.train()
    data_len = len(train_loader.dataset)
    for batch_idx, (spectrograms, labels, input_lengths, label_lengths) in enumerate(train_loader):
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class), as CTCLoss expects
        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped once per batch, not per epoch
        batch_loss = loss.item()  # hoisted: .item() syncs with the device
        # Original condition was `batch_idx % 1 == 0 or batch_idx == data_len`:
        # the first clause is always true and the second compares a batch
        # index against the *dataset* length, so it was never the trigger.
        if batch_idx % log_interval == 0:
            printout('\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(spectrograms), data_len,
                100. * batch_idx / len(train_loader), batch_loss))
            if batch_loss < MIN_LOSS:
                # Saving model checkpoint
                printout(f'Saving model checkpoint with min loss of {batch_loss}...')
                torch.save(model.state_dict(), CHECKPOINT_PATH)
                MIN_LOSS = batch_loss
# --- Device selection ---
use_cuda = torch.cuda.is_available()
torch.manual_seed(7)  # fixed seed for reproducible shuffling/initialization
device = torch.device('cuda' if use_cuda else 'cpu')
# device = 'cpu'
printout(f'Using {device}')

# --- Data pipeline ---
train_dataset = CommandDataset(csv_file=CSV_FILE, root_dir=ROOT_DIR)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          collate_fn=lambda x: preprocess_model(x, 'train'))

# --- Model: resume from a saved checkpoint if one exists ---
model = SpeechRecognitionModel().to(device)
if os.path.isfile(MODEL_PATH):
    printout('Saved model checkpoint found, loading it...')
    # map_location makes a GPU-saved checkpoint loadable on a CPU-only
    # machine (and vice versa); without it torch.load restores tensors to
    # the device they were saved on and fails if it is unavailable.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

optimizer = optim.AdamW(model.parameters(), MAX_LEARNING_RATE)
# blank=28: CTC blank token index — presumably one past the last character
# class emitted by the model; TODO confirm against SpeechRecognitionModel.
criterion = nn.CTCLoss(blank=28).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                          max_lr=MAX_LEARNING_RATE,
                                          steps_per_epoch=len(train_loader),
                                          epochs=EPOCHS,
                                          anneal_strategy='linear')

printout('Training...')
for epoch in tqdm(range(1, EPOCHS + 1)):
    train(model, device, train_loader, criterion, optimizer, scheduler, epoch)

# Persist the best checkpoint (lowest batch loss seen) as the final model.
printout(f'Saving trained model with the least loss of {MIN_LOSS}...')
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
torch.save(model.state_dict(), 'command_model_trained.pth')