-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
98 lines (77 loc) · 3.78 KB
/
train.py
File metadata and controls
98 lines (77 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import random
from tokens import *
from model import *
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence#, masked_cross_entropy
from masked_cross_entropy import *
def pad_seq(seq, max_length):
    """Return a copy of *seq* right-padded with PAD_token to *max_length* items.

    The caller's ``seq`` is not modified; a new list is always returned, so
    lengths measured on the original sequence stay valid regardless of call
    order. Sequences already ``max_length`` long or longer are copied
    unchanged (they are never truncated).

    Args:
        seq: sequence of token indexes (list or tuple).
        max_length: target length after padding.

    Returns:
        A new list of length ``max(len(seq), max_length)``.
    """
    # A negative repeat count yields an empty list, so over-long inputs are
    # simply copied.
    return list(seq) + [PAD_token] * (max_length - len(seq))
def indexes_from_sentence(lang, sentence):
    """Map each word of *sentence* to its vocabulary index, then append EOS.

    The sentence is split on single ``' '`` characters, so runs of spaces
    yield empty-string tokens that are still looked up via
    ``lang.get_index_from_word``.

    Args:
        lang: vocabulary object exposing ``get_index_from_word``.
        sentence: space-separated sentence string.

    Returns:
        A new list of indexes, one per token, followed by EOS_token.
    """
    lookup = lang.get_index_from_word
    indexes = [lookup(token) for token in sentence.split(' ')]
    indexes.append(EOS_token)
    return indexes
def random_batch(input_lang, output_lang, batch_size, pairs, USE_CUDA=True):
    """Sample a random, padded batch of (input, target) sentence pairs.

    Pairs are drawn with replacement from *pairs* using ``random.choice``.
    Each pair is converted to index lists (each ending in EOS_token). The
    batch is sorted by input length, longest first, and both sides are
    right-padded with PAD_token to their own maximum length.

    Args:
        input_lang: vocabulary object used to index ``pair[0]``.
        output_lang: vocabulary object used to index ``pair[1]``.
        batch_size: number of pairs to sample; must be at least 1.
        pairs: sequence of (input_sentence, target_sentence) pairs.
        USE_CUDA: if true, move the returned tensors to the GPU.

    Returns:
        ``(input_var, input_lengths, target_var, target_lengths)``. The two
        tensors are int64 with shape (max_len, batch_size). The length lists
        hold the unpadded lengths in the same (sorted) batch order.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # The original code also failed with ValueError here (while unpacking
        # an empty zip); this just makes the message explicit.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    input_seqs = []
    target_seqs = []
    for _ in range(batch_size):
        pair = random.choice(pairs)
        input_seqs.append(indexes_from_sentence(input_lang, pair[0]))
        target_seqs.append(indexes_from_sentence(output_lang, pair[1]))

    # Sort by input length, longest first (stable sort), then unzip.
    # NOTE(review): presumably required because the encoder packs the batch
    # with pack_padded_sequence, which by default expects descending lengths.
    seq_pairs = sorted(zip(input_seqs, target_seqs), key=lambda p: len(p[0]), reverse=True)
    input_seqs, target_seqs = zip(*seq_pairs)

    # Record true lengths before padding, and compute each max only once.
    input_lengths = [len(s) for s in input_seqs]
    target_lengths = [len(s) for s in target_seqs]
    max_input_len = max(input_lengths)
    max_target_len = max(target_lengths)
    input_padded = [pad_seq(s, max_input_len) for s in input_seqs]
    target_padded = [pad_seq(s, max_target_len) for s in target_seqs]

    # (batch_size x max_len) -> (max_len x batch_size). Plain tensors replace
    # the deprecated autograd.Variable wrapper (which just returns a tensor).
    input_var = torch.tensor(input_padded, dtype=torch.long).transpose(0, 1)
    target_var = torch.tensor(target_padded, dtype=torch.long).transpose(0, 1)

    if USE_CUDA:
        input_var = input_var.cuda()
        target_var = target_var.cuda()

    return input_var, input_lengths, target_var, target_lengths
def train(input_batches, input_lengths, target_batches, target_lengths, batch_size,
          encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, clip,
          max_length, USE_CUDA=True):
    """Run one teacher-forced optimisation step on a padded batch.

    The step proceeds as follows:
    1. Encode ``input_batches``.
    2. Decode one time step at a time, always feeding the ground-truth token
       ``target_batches[t]`` as the next decoder input (teacher forcing on
       every step).
    3. Compute ``masked_cross_entropy`` over ``target_lengths`` and
       backpropagate.
    4. Clip each model's gradient norm to ``clip`` and step both optimizers.

    Args:
        input_batches: input index tensor, passed straight to the encoder.
            Presumably (max_input_len, batch_size) as built by random_batch
            (TODO confirm).
        input_lengths: unpadded input lengths, passed to the encoder.
        target_batches: target index tensor indexed time-step first
            (``target_batches[t]`` is the decoder input after step t).
        target_lengths: unpadded target lengths. The decoder runs for
            ``max(target_lengths)`` steps.
        batch_size: number of sequences in the batch.
        encoder: encoder model, called as ``encoder(inputs, lengths, None)``.
        decoder: decoder model; must expose ``n_layers`` and ``output_size``.
        encoder_optimizer: optimizer for the encoder's parameters.
        decoder_optimizer: optimizer for the decoder's parameters.
        criterion: unused (the loss is masked_cross_entropy); kept for
            interface compatibility.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        max_length: unused; kept for interface compatibility.
        USE_CUDA: if true, move the tensors created here to the GPU.

    Returns:
        ``(loss, encoder_norm, decoder_norm)``. ``loss`` is a Python number
        from ``loss.item()``. The two norms are the total gradient norms
        returned by ``clip_grad_norm_``.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Run the whole input batch through the encoder.
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lengths, None)

    # Every sequence starts decoding from SOS_token.
    decoder_input = torch.tensor([SOS_token] * batch_size, dtype=torch.long)
    # NOTE(review): this takes the first decoder.n_layers slices of the
    # encoder hidden state. For a bidirectional PyTorch RNN, h_n is ordered
    # layer-major (forward/backward per layer), so this is not simply "the
    # forward direction". Confirm against the encoder in model.py.
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    max_target_length = max(target_lengths)
    # Buffer for every step's decoder output: (max_target_length, batch, output_size).
    all_decoder_outputs = torch.zeros(max_target_length, batch_size, decoder.output_size)

    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()

    # Decode one time step at a time.
    for t in range(max_target_length):
        decoder_output, decoder_hidden, _attn = decoder(
            decoder_input, decoder_hidden, encoder_outputs, USE_CUDA
        )
        all_decoder_outputs[t] = decoder_output
        decoder_input = target_batches[t]  # teacher forcing: next input is the true token

    # The loss is computed once over the whole (masked) sequence, not accumulated per word.
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(),  # -> (batch, seq, output_size)
        target_batches.transpose(0, 1).contiguous(),  # -> (batch, seq)
        target_lengths,
    )
    loss.backward()

    # Clip gradient norms separately for each model.
    ec = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    dc = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item(), ec, dc