-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
135 lines (109 loc) · 4.26 KB
/
utils.py
File metadata and controls
135 lines (109 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import sys
import time
import torch
import torch.nn as nn
from typing import Dict
from datasets.dataset import CDDataset
from torch.utils.data import DataLoader
# from models.multi_conv import Multi_conv
from models.BIT import define_G, BASE_Transformer
from models.FPT import FPT
from models.FPT_6in1 import FPT_6in1
from models.Lapsrn import Lapsrn
from models.multi_conv import Multi_conv
from losses.crossentropy import cross_entropy
def get_loader(config: Dict, type: str) -> DataLoader:
    """Build a DataLoader for one dataset split.

    Args:
        config: Must contain ``batch_size`` and the directory key for the
            requested split (``train_dir``, ``val_dir`` or ``test_dir``).
            An optional ``num_workers`` entry overrides the default of 4.
        type: Which split to load: ``'train'``, ``'val'`` or ``'test'``.
            The name shadows the builtin but is kept so keyword callers
            keep working.

    Returns:
        A DataLoader over ``CDDataset(config['<split>_dir'])``. Only the
        training split is shuffled.

    Raises:
        TypeError: if ``type`` is not one of the supported splits.
    """
    # split name -> (config key holding the data directory, shuffle?)
    splits = {
        'train': ('train_dir', True),
        'val': ('val_dir', False),
        'test': ('test_dir', False),
    }
    if type not in splits:
        raise TypeError(f"{type} is invalid, should be 'train', 'val' or 'test'")
    dir_key, shuffle = splits[type]
    dataset = CDDataset(config[dir_key])
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config.get('num_workers', 4),
    )
def get_model(config: Dict) -> nn.Module:
    """Instantiate the network named by ``config['model']``.

    Supported names are ``'Lapsrn'``, ``'Multi_conv'``, ``'BIT'``, ``'FPT'``
    and ``'FPT_6in1'``. Every model is built with its constructor defaults
    except BIT, which receives the fixed hyper-parameters listed below.

    Raises:
        NotImplementedError: if the name is not one of the above.
    """
    name = config['model']
    # Factories are lambdas so a model class is only looked up when selected.
    factories = (
        ('Lapsrn', lambda: Lapsrn()),
        ('Multi_conv', lambda: Multi_conv()),
        ('BIT', lambda: BASE_Transformer(
            input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
            with_pos='learned', enc_depth=1, dec_depth=8,
        )),
        ('FPT', lambda: FPT()),
        ('FPT_6in1', lambda: FPT_6in1()),
    )
    for key, factory in factories:
        if name == key:
            return factory()
    raise NotImplementedError(f"{name} is not implemented")
def get_optimizer(model: nn.Module, config: Dict):
    """Create the optimizer selected by ``config['optimizer']``.

    ``config['lr']`` sets the learning rate for every choice. SGD also uses
    momentum 0.9 and weight decay 5e-4. Adam and AdamW keep PyTorch's
    defaults for all other arguments.

    Raises:
        NotImplementedError: for any name other than 'Adam', 'SGD' or 'AdamW'.
    """
    choice = config['optimizer']
    # name -> (optimizer class, keyword arguments passed in addition to lr)
    table = (
        ('Adam', torch.optim.Adam, {}),
        ('SGD', torch.optim.SGD, {'momentum': 0.9, 'weight_decay': 5e-4}),
        ('AdamW', torch.optim.AdamW, {}),
    )
    for key, opt_cls, extra in table:
        if choice == key:
            return opt_cls(model.parameters(), lr=config['lr'], **extra)
    raise NotImplementedError(f"{choice} is not implemented")
def get_loss_fn(config: Dict):
    """Return the loss module named by ``config['loss_fn']``.

    Only ``'CrossEntropy'`` is supported; it maps to ``nn.CrossEntropyLoss``
    with default settings.

    Raises:
        NotImplementedError: for any other name.
    """
    name = config['loss_fn']
    if name != 'CrossEntropy':
        raise NotImplementedError(f"{name} is not implemented")
    return nn.CrossEntropyLoss()
class Timer:
    """Wall-clock timer that tracks total run time, per-stage time and an ETA.

    Usable as a context manager (``with Timer() as t:``). Leaving the block
    does nothing and does not suppress exceptions.
    """

    def __init__(self, starting_msg=None):
        """Start timing now; if given, print ``starting_msg`` with the current time."""
        self.start = time.time()
        self.stage_start = self.start
        # update_progress() fills these in. Until it is called they hold
        # placeholders (nothing elapsed, zero remaining, finish = start time).
        # This keeps the accessors below from raising AttributeError.
        self.elapsed = 0.0
        self.est_total = 0.0
        self.est_remaining = 0.0
        self.est_finish = int(self.start)
        if starting_msg is not None:
            print(starting_msg, time.ctime(time.time()))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception raised in the ``with`` body propagate.
        return

    def update_progress(self, progress):
        """Recompute the time estimates from the fraction of work completed.

        Args:
            progress: Fraction of the total work done so far. Must be > 0;
                presumably in (0, 1].

        Raises:
            ValueError: if ``progress`` is not positive (previously this
                crashed with ZeroDivisionError or produced negative estimates).
        """
        if progress <= 0:
            raise ValueError(f"progress must be positive, got {progress}")
        self.elapsed = time.time() - self.start
        # Linear extrapolation: total time = elapsed / fraction done.
        self.est_total = self.elapsed / progress
        self.est_remaining = self.est_total - self.elapsed
        self.est_finish = int(self.start + self.est_total)

    def str_estimated_complete(self):
        """Return the estimated finish time as a ``time.ctime`` string."""
        return str(time.ctime(self.est_finish))

    def str_estimated_remaining(self):
        """Return the estimated remaining time in hours, suffixed with ``'h'``."""
        return str(self.est_remaining/3600) + 'h'

    def estimated_remaining(self):
        """Return the estimated remaining time in hours."""
        return self.est_remaining/3600

    def get_stage_elapsed(self):
        """Return seconds since the current stage started, without resetting it."""
        return time.time() - self.stage_start

    def reset_stage(self):
        """Start a new stage now."""
        self.stage_start = time.time()

    def lapse(self):
        """Return seconds since the current stage started, then start a new stage."""
        # Read the clock once so no time falls between the two stages.
        now = time.time()
        out = now - self.stage_start
        self.stage_start = now
        return out
class Logger(object):
    """Echo every message to the terminal and append it to a log file.

    Provides ``write`` and ``flush``, the same method names as a text stream.
    The terminal is whatever ``sys.stdout`` was when the logger was created.
    """

    def __init__(self, outfile):
        """Bind to the current ``sys.stdout`` and write a timestamped header to both outputs.

        Args:
            outfile: Path of the log file; messages are appended to it.
        """
        self.terminal = sys.stdout
        self.log_path = outfile
        now = time.strftime("%c")
        self.write('================ (%s) ================\n' % now)

    def write(self, message):
        """Write ``message`` to the terminal and append it to the log file."""
        self.terminal.write(message)
        # The file is reopened per call and closed immediately, so each message
        # reaches the file without an explicit flush. Use UTF-8 explicitly so
        # non-ASCII text does not fail under a non-UTF-8 locale encoding.
        with open(self.log_path, mode='a', encoding='utf-8') as f:
            f.write(message)

    def write_dict(self, dict):
        """Write all ``key: value`` pairs on one line, values formatted as ``%.7f``.

        The parameter name shadows the builtin ``dict``; it is kept for
        backward compatibility with keyword callers.
        """
        self.write(''.join('%s: %.7f ' % (k, v) for k, v in dict.items()))

    def write_dict_str(self, dict):
        """Write all ``key: value`` pairs on one line, values formatted with ``%s``."""
        self.write(''.join('%s: %s ' % (k, v) for k, v in dict.items()))

    def flush(self):
        """Flush the terminal stream; the log file is already closed after each write."""
        self.terminal.flush()