-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathdatasets.py
More file actions
187 lines (152 loc) · 7.03 KB
/
datasets.py
File metadata and controls
187 lines (152 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os, errno
from scipy import misc
import numpy as np
class Dataset:
    """
    Contains the dataset to be used in meta-learning

    self.items contains a dictionary of class_index: [item, item, ...] elements
    self.idx_classes contains a dictionary of class_name: class_index elements
    self.get_data is the callable stored alongside the items (kept as given)
    self.single_example_size is stored as given
    """

    def __init__(self, name, items, idx_classes, get_data: callable, single_example_size):
        self.name = name
        self.items = items
        self.idx_classes = idx_classes
        self.get_data = get_data
        self.single_example_size = single_example_size

    def __getitem__(self, item):
        """returns the list of items stored under the class index `item`"""
        return self.items[item]

    def n_classes(self):
        """returns the number of classes (entries of self.idx_classes)"""
        return len(self.idx_classes)

    def classes(self):
        """returns a list containing all the classes names"""
        return list(self.idx_classes.keys())

    @staticmethod
    def union(d1, d2, u_name):
        """
        Build a new Dataset named `u_name` holding the classes and items of d1 and d2.

        Classes are matched by class name. Classes of d1 keep their index; classes
        found only in d2 receive fresh indices. For a class present in both, items
        of d2 that are not already listed (compared with ==) are appended.
        Items are stored unchanged, so any class index embedded inside an item is
        NOT rewritten to the index used by the union.
        Neither d1 nor d2 is modified.

        Raises:
        - TypeError: if d1 or d2 is not a Dataset
        - ValueError: if d1 and d2 do not share get_data and single_example_size
        """
        if not isinstance(d1, Dataset) or not isinstance(d2, Dataset):
            raise TypeError('d1 and d2 must be both Datasets')
        # The union can only carry one loader and one example size, so both
        # inputs have to agree on them.
        if d1.get_data != d2.get_data or d1.single_example_size != d2.single_example_size:
            raise ValueError('d1 and d2 must share get_data and single_example_size')

        # Copy every list so that appending below never touches d1's own lists.
        u_items = {idx: list(class_items) for idx, class_items in d1.items.items()}
        u_idx_classes = dict(d1.idx_classes)

        # First index that is used neither as a class index nor as an items key.
        next_idx = max(list(u_idx_classes.values()) + list(u_items), default=-1) + 1

        for class_name, idx_d2 in d2.idx_classes.items():
            items_d2 = d2.items.get(idx_d2, [])
            if class_name not in u_idx_classes:
                u_idx_classes[class_name] = next_idx
                u_items[next_idx] = list(items_d2)
                next_idx += 1
            else:
                merged = u_items.setdefault(u_idx_classes[class_name], [])
                for item in items_d2:
                    if item not in merged:
                        merged.append(item)

        return Dataset(u_name, u_items, u_idx_classes, d1.get_data, d1.single_example_size)
class Omniglot:
    """
    Omniglot images, optionally downloaded, exposed as a train and a test Dataset.

    self.train and self.test are Dataset objects. Their items are dicts of
    class_index: [(file_path, class_index, rotation), ...] elements and their
    idx_classes are dicts of class_name: class_index elements. A class name is the
    last two directory names of the folder holding the image, joined by '/',
    followed by the rotation value.

    Args:
    - root: the directory where the dataset will be stored
    - download: need to download the dataset
    - rotations: array of rotation [0, 1, 2, 3] contains rotations of 0, 90, 180, 270 degrees
    - split: value of training classes without considering rotations, if none folder splitting will be used
    - example_size: stored on both Datasets as single_example_size
    """
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    # Number of training classes (before rotations) used when split is None;
    # intended to reproduce the images_background / images_evaluation folder split.
    default_split = 964

    def __init__(self, root, download=False, rotations=None, split=None, example_size=(105, 105, 1)):
        print('Loading Omniglot with rotations: {}, split: {}, ex_size: {}'.format(rotations, split, example_size))
        self.root = root
        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
        self.split = split
        self.rotations = [0] if rotations is None else rotations
        processed_dir = os.path.join(self.root, self.processed_folder)
        train_item, train_idx, test_item, test_idx = self.find_items_and_split(processed_dir)
        self.train = Dataset('train', train_item, train_idx, self.get_data, example_size)
        self.test = Dataset('test', test_item, test_idx, self.get_data, example_size)

    def download(self):
        """
        download files from url into raw folder and then decompress into processed folder.
        Does nothing when the processed folders are already present.
        :return:
        """
        import shutil
        import urllib.request
        import zipfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)
        # exist_ok so that an already existing raw folder does not prevent the
        # processed folder from being created.
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(processed_dir, exist_ok=True)

        for url in self.urls:
            print('>>Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of holding the whole archive in memory;
            # both handles are closed even if the transfer fails.
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            print(">>Unzip from " + file_path + " to " + processed_dir)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(processed_dir)
        print("<<Download finished.")

    def _check_exists(self):
        """returns True when both extracted image folders exist under root/processed"""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed_dir, "images_evaluation")) and \
            os.path.exists(os.path.join(processed_dir, "images_background"))

    def find_items_and_split(self, root_dir):
        """
        Walk root_dir, register every png file under its class and split classes
        into train and test.

        Classes are assigned to train in the order they are found until
        split * len(rotations) train classes exist (default_split is used when
        self.split is None); every class found afterwards goes to test.
        Directories and files are visited in sorted order so the split does not
        depend on the file system listing order.

        :return: train_items, train_idx_classes, test_items, test_idx_classes
        """
        train_items = {}
        train_idx_classes = {}
        test_items = {}
        test_idx_classes = {}
        idx_classes = train_idx_classes
        items = train_items
        n_train = self.default_split if self.split is None else self.split
        cur_split = n_train * len(self.rotations)
        for (root, dirs, files) in os.walk(root_dir):
            # Sorting in place makes os.walk descend in a deterministic order
            # ('images_background' is visited before 'images_evaluation').
            dirs.sort()
            for f in sorted(files):
                if not f.endswith("png"):
                    continue
                # Last two directory names, independent of the path separator.
                class_name = os.path.basename(os.path.dirname(root)) + "/" + os.path.basename(root)
                if not self._check_class(class_name, idx_classes):
                    if len(idx_classes) > cur_split - 1:
                        # Train quota reached: everything new goes to test from now on.
                        idx_classes = test_idx_classes
                        items = test_items
                        cur_split = float('Inf')
                    self._add_class(idx_classes, items, class_name)
                self._add_item(idx_classes, items, class_name, os.path.join(root, f))
        print("Classes Found: [%d, %d] ([train, test]) " % (len(train_idx_classes), len(test_idx_classes)))
        return train_items, train_idx_classes, test_items, test_idx_classes

    def _check_class(self, class_name, idx_classes):
        """returns True when class_name (with the first rotation) is already registered"""
        return class_name + str(self.rotations[0]) in idx_classes

    def _add_class(self, idx_classes, items, class_name):
        """registers one new class per rotation, each with an empty item list"""
        for r in self.rotations:
            r_class_name = class_name + str(r)
            idx_classes[r_class_name] = len(idx_classes)
            items[idx_classes[r_class_name]] = []

    def _add_item(self, idx_classes, items, class_name, item_path):
        """appends (item_path, class_index, rotation) to the class of every rotation"""
        for r in self.rotations:
            class_idx = idx_classes[class_name + str(r)]
            items[class_idx].append((item_path, class_idx, r))

    @staticmethod
    def get_data(item):
        """read image from item path (item[0]) and rotate it by item[2] quarter turns"""
        # NOTE(review): scipy.misc.imread is no longer shipped by recent SciPy
        # releases (believed removed in 1.2) - confirm the pinned SciPy version or
        # switch to another image reader.
        img = np.array(misc.imread(item[0]))
        img = np.rot90(img, int(item[2]))
        return img
if __name__ == "__main__":
    # Script entry point: build the dataset under ./omniglot, downloading it
    # first when the processed folders are missing.
    o = Omniglot(download=True, root='omniglot')