-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path OverfittingTest.py
More file actions
65 lines (62 loc) · 1.91 KB
/
OverfittingTest.py
File metadata and controls
65 lines (62 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
import PIL.Image as Image
import os
import Testing as testNet
import Network as net
#####################Function############################
def printImage(img):
    """Show an image tensor (or a batch of them) as a single grid via matplotlib.

    The tensor is detached and moved to the CPU, tiled with
    ``torchvision.utils.make_grid``, mapped back from the ``(x - 0.5) / 0.5``
    normalisation used in this script (``x * 0.5 + 0.5``), and displayed with
    ``plt.imshow`` followed by ``plt.show()``.

    Args:
        img: tensor accepted by ``make_grid`` -- presumably (C, H, W) or
            (B, C, H, W) with values around [-1, 1]; TODO confirm for the
            intermediate outputs returned by the network.
    """
    img = img.detach().cpu()
    grid = torchvision.utils.make_grid(img)
    # make_grid yields (C, H, W); matplotlib expects an (H, W, C) NumPy array.
    # Convert explicitly instead of relying on np.transpose's fallback path for
    # non-NumPy inputs.
    grid = grid.permute(1, 2, 0).numpy()
    # Undo Normalize(mean=0.5, std=0.5). Clipping to [0, 1] matches what imshow
    # would do to float RGB data anyway, but without the clipping warning.
    grid = np.clip(grid * 0.5 + 0.5, 0.0, 1.0)
    plt.imshow(grid)
    plt.show()
#################Hyper Parameter#########################
# Directory holding the input/target image pair. Set the OVERFITTING_TEST_DIR
# environment variable to use another location without editing this
# Windows-specific default.
img_path = os.environ.get('OVERFITTING_TEST_DIR', 'C:/Datasets/OverfittingTest')
fname = '1803151818-00000048.jpg'   # image fed to the network
fname2 = '1803151818-00000048.png'  # target image the L1 loss compares against
# Alternative model: net = testNet.BasicNet()
# NOTE(review): this rebinds `net`, shadowing the module-level
# `import Network as net` alias, so that module is not reachable under that
# name in this file.
net = testNet.TestNet_Pool()
#####################Etc#################################
# Run on the first CUDA device when available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# PIL image -> float tensor in [0, 1] (for 8-bit images) -> (x - 0.5) / 0.5.
# A single-element mean/std is broadcast over every channel by Normalize, so
# this one transform serves grayscale and RGB images alike.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Mean absolute error between the network output and the target image.
l1_loss = nn.L1Loss()
def _load_image_tensor(path, size=(256, 256)):
    """Load an image, resize it, normalise it and return it as a batch of one.

    Applies the module-level ``transform``, moves the tensor to ``device`` and
    adds a leading batch dimension of size 1. The channel count depends on the
    file's PIL mode.
    """
    # The context manager guarantees the file handle is released; resize()
    # returns an independent, fully loaded image.
    with Image.open(path) as pil_img:
        resized = pil_img.resize(size)
    return transform(resized).to(device).unsqueeze(0)


if __name__ == '__main__':
    img = _load_image_tensor(os.path.join(img_path, fname))
    # Target the network is trained to reproduce from the single input image.
    gt = _load_image_tensor(os.path.join(img_path, fname2))
    print(gt.shape)
    printImage(gt)

    net.to(device)
    optimizer = optim.Adam(params=net.parameters(), lr=0.002)
    epoch = 1000
    for e in range(epoch):
        # The network returns a pair: the output compared against `gt`, plus an
        # extra tensor that is only visualised (presumably an intermediate
        # result; TODO confirm against testNet).
        result, temp = net(img)
        loss = l1_loss(result, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if e % 100 == 99:
            print("e: %d, loss: %f" % (e, loss.item()))

    # Outputs from the last forward pass (computed before the final step).
    printImage(img)
    printImage(temp)
    printImage(result)