-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
109 lines (89 loc) · 4.34 KB
/
data.py
File metadata and controls
109 lines (89 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
import os
import random

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
# Directory holding the training ``<image_id>.jpg`` files read by DatasetRetriever.
# The default is a developer-specific absolute path; set the WHEAT_TRAIN_ROOT
# environment variable to use another location without editing the source.
TRAIN_ROOT_PATH = os.environ.get("WHEAT_TRAIN_ROOT", "/home/eragon/Desktop/Datasets/Wheat/train")
# Load dataset
class DatasetRetriever(Dataset):
    """Map-style dataset yielding ``(image, target, image_id)`` training samples.

    ``target`` is a dict with ``'boxes'``, ``'labels'`` (all ones: single class) and
    ``'image_id'`` (a tensor holding the integer *index*). Outside test mode, about
    half of the samples are built with cutmix/mosaic from four images.

    Args:
        marking: annotation table filtered by an ``'image_id'`` column, with box
            columns ``'x'``, ``'y'``, ``'w'``, ``'h'`` (presumably a pandas
            DataFrame — confirm against the caller).
        image_ids: indexable sequence of image identifiers; each names a
            ``TRAIN_ROOT_PATH/<image_id>.jpg`` file.
        transforms: optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` that returns a dict
            with ``'image'`` and ``'bboxes'`` keys (albumentations-style — confirm).
        test: if True, cutmix is never applied.
    """

    def __init__(self, marking, image_ids, transforms=None, test=False):
        super().__init__()
        self.image_ids = image_ids
        self.marking = marking
        self.transforms = transforms
        self.test = test

    def __getitem__(self, index: int):
        """Return ``(image, target, image_id)`` for the sample at *index*.

        Without a successful transform, ``target['boxes']`` is the NumPy array from the
        loader in ``[x_min, y_min, x_max, y_max]`` order. After a successful transform
        it is a tensor whose columns are reordered to ``[y_min, x_min, y_max, x_max]``.
        """
        image_id = self.image_ids[index]
        # Plain load with probability ~0.5 (always in test mode), cutmix otherwise.
        if self.test or random.random() > 0.5:
            image, boxes = self.load_image_and_boxes(index)
        else:
            image, boxes = self.load_cutmix_image_and_boxes(index)

        # There is only one class, so every box is labelled 1.
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([index]),
        }

        if self.transforms:
            # A random transform may remove every box. Such attempts are discarded and
            # retried on the untouched image. After 10 failures the sample is returned
            # untransformed.
            for _ in range(10):
                sample = self.transforms(**{
                    'image': image,
                    'bboxes': target['boxes'],
                    'labels': labels,
                })
                if len(sample['bboxes']) > 0:
                    image = sample['image']
                    target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
                    # yxyx: swap columns into [y_min, x_min, y_max, x_max] order, assuming
                    # the transform returns boxes in the xyxy order it was given.
                    # Note that this differs from the untransformed path.
                    target['boxes'][:, [0, 1, 2, 3]] = target['boxes'][:, [1, 0, 3, 2]]
                    # The transform may have dropped boxes; keep labels the same length.
                    target['labels'] = torch.ones((target['boxes'].shape[0],), dtype=torch.int64)
                    break
        return image, target, image_id

    def __len__(self) -> int:
        """Return the number of images, i.e. ``len(image_ids)``."""
        return len(self.image_ids)

    def load_image_and_boxes(self, index):
        """Load image *index* as an RGB float32 array scaled by 1/255, plus its boxes.

        Boxes are returned as an ``(N, 4)`` array in ``[x_min, y_min, x_max, y_max]``
        order, converted from the annotation's ``(x, y, w, h)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_id = self.image_ids[index]
        path = f"{TRAIN_ROOT_PATH}/{image_id}.jpg"
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        records = self.marking[self.marking['image_id'] == image_id]
        boxes = records[['x', 'y', 'w', 'h']].values
        # (x, y, w, h) -> (x_min, y_min, x_max, y_max)
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        return image, boxes

    # This implementation of cutmix author: https://www.kaggle.com/nvnnghia
    # https://www.kaggle.com/shonenkov
    def load_cutmix_image_and_boxes(self, index, imsize=1024):
        """Build an ``imsize`` x ``imsize`` mosaic from four images and merge their boxes.

        A centre ``(xc, yc)`` is drawn uniformly from the middle half of the canvas.
        Image *index* fills the top-left region and three randomly chosen images fill
        the top-right, bottom-left and bottom-right regions. Each source is cropped to
        its region and its boxes shifted by the same offset. All boxes are then clipped
        to ``[0, 2 * (imsize // 2)]`` and cast to int32, and zero-area boxes are dropped.

        The crop arithmetic treats every source image as ``imsize`` x ``imsize``. This
        is not checked.

        Returns:
            ``(image, boxes)``: a float32 ``(imsize, imsize, 3)`` canvas (initialised to
            1.0) and an ``(N, 4)`` int32 array of xyxy boxes.
        """
        w, h = imsize, imsize
        s = imsize // 2
        xc, yc = [int(random.uniform(imsize * 0.25, imsize * 0.75)) for _ in range(2)]  # centre x, y
        # Anchor image first, then three random partners (repeats are possible).
        indexes = [index] + [random.randint(0, len(self.image_ids) - 1) for _ in range(3)]

        result_image = np.full((imsize, imsize, 3), 1, dtype=np.float32)
        result_boxes = []
        for i, idx in enumerate(indexes):
            image, boxes = self.load_image_and_boxes(idx)
            # (x1a, y1a, x2a, y2a): destination region on the canvas.
            # (x1b, y1b, x2b, y2b): region cropped from the source image.
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
            else:  # i == 3: bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
            result_image[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
            # Offset mapping source-image coordinates onto canvas coordinates.
            padw = x1a - x1b
            padh = y1a - y1b
            boxes[:, 0] += padw
            boxes[:, 1] += padh
            boxes[:, 2] += padw
            boxes[:, 3] += padh
            result_boxes.append(boxes)

        result_boxes = np.concatenate(result_boxes, 0)
        np.clip(result_boxes, 0, 2 * s, out=result_boxes)
        result_boxes = result_boxes.astype(np.int32)
        # Drop boxes that collapsed to zero area after clipping/truncation.
        areas = (result_boxes[:, 2] - result_boxes[:, 0]) * (result_boxes[:, 3] - result_boxes[:, 1])
        result_boxes = result_boxes[areas > 0]
        return result_image, result_boxes