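"""Train a GapT or baseline model for time-series gap filling.

Minimal usage sketch (the flag values and paths below are illustrative,
not prescribed defaults):

    python train.py --model gapt --encoder_type transformer \
        --data_dir /path/to/data --output_dir /path/to/output
"""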
import argparse
import json
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from modules.baseline import Baseline
from modules.constants import feature_list
from modules.data import GapFillingDataset
from modules.gapt import GapT
from modules.utils import LossCallback

if __name__ == '__main__':
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='gapt', choices=['gapt', 'baseline'])
    parser.add_argument('--mode', type=str, default='default', choices=['default', 'naive'])
    parser.add_argument('--encoder_type', type=str, default='transformer', choices=['transformer', 'lru', 'lstm', 'gru', 'tcn', 'mlp'])
    parser.add_argument('--time_encoding', type=str, default=None, choices=['time2vec', 'periodic'])
    parser.add_argument('--d_output', type=int, default=1)
    parser.add_argument('--n_head', type=int, default=4)
    parser.add_argument('--d_model', type=int, default=64)
    parser.add_argument('--d_embedding', type=int, default=64)
    parser.add_argument('--n_layers', type=int, default=2)
    parser.add_argument('--kernel_size', type=int, default=3)
    parser.add_argument('--d_hidden', type=int, default=128)
    parser.add_argument('--enc_dropout', type=float, default=0.0)
    parser.add_argument('--ff_dropout', type=float, default=0.2)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--optimizer', type=str, default='momo', choices=['momo', 'adam'])
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--devices', default=1, type=int)
    parser.add_argument('--num_workers', default=16, type=int)
    args = parser.parse_args()

    # Echo the parsed configuration
    for arg in vars(args):
        print(f'{arg}: {getattr(args, arg)}')
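
    # paths.json is expected to provide an entry per split ('train', 'val',
    # 'test' -- keys taken from the lookups below). A minimal sketch, assuming
    # each value is whatever path(s) GapFillingDataset consumes:
    #
    #   {"train": "...", "val": "...", "test": "..."}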
    # Create the datasets
    with open(os.path.join(args.data_dir, 'paths.json'), 'r') as f:
        data_paths = json.load(f)
    train_dataset = GapFillingDataset(data_paths['train'], feature_list)
    val_dataset = GapFillingDataset(data_paths['val'], feature_list)
    test_dataset = GapFillingDataset(data_paths['test'], feature_list)

    # Create the dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # Initialize the model
    d_input = len(feature_list)
    if args.model == 'gapt':
        model = GapT(
            encoder_type=args.encoder_type,
            mode=args.mode,
            time_encoding=args.time_encoding,
            d_input=d_input,
            d_model=args.d_model,
            n_head=args.n_head,
            kernel_size=args.kernel_size,
            d_hidden=args.d_hidden,
            n_layers=args.n_layers,
            d_output=args.d_output,
            enc_dropout=args.enc_dropout,
            ff_dropout=args.ff_dropout,
            learning_rate=args.learning_rate,
            optimizer=args.optimizer,
        )
    elif args.model == 'baseline':
        model = Baseline(
            d_input=d_input,
            d_embedding=args.d_embedding,
            d_model=args.d_model,
            d_output=args.d_output,
            learning_rate=args.learning_rate,
            dropout_rate=args.ff_dropout,
            optimizer=args.optimizer,
        )
    else:
        raise ValueError(f'Invalid model type: {args.model}')
    # Store losses for later inspection
    loss_callback = LossCallback()

    # Initialize a trainer
    trainer = pl.Trainer(
        accelerator='auto',
        devices=args.devices,
        max_epochs=args.epochs,
        log_every_n_steps=1,
        logger=pl.loggers.TensorBoardLogger('logs/'),
        callbacks=[loss_callback],
    )

    # Train the model
    trainer.fit(model, train_dataloader, val_dataloader)

    # Evaluate the trained model on all three splits; passing the model
    # explicitly tests the in-memory weights rather than a checkpoint
    metrics = trainer.test(model, dataloaders=[train_dataloader, val_dataloader, test_dataloader])

    # Save the model
    os.makedirs(args.output_dir, exist_ok=True)
    trainer.save_checkpoint(os.path.join(args.output_dir, 'model.ckpt'))
    # Save metadata
    metadata = {
        'args': vars(args),
        'feature_list': feature_list,
        'metrics': metrics,
        'params': sum(p.numel() for p in model.parameters() if p.requires_grad),
        'epoch_losses': loss_callback.losses,
    }
    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
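
    # Sketch (assumption, not exercised by this script): the saved checkpoint
    # can later be restored with the matching LightningModule class, provided
    # the module calls self.save_hyperparameters() in __init__, e.g.:
    #
    #   model = GapT.load_from_checkpoint(os.path.join(args.output_dir, 'model.ckpt'))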