From a0a84ef6b024eecefcaef76a4aac98b9081cb304 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sat, 26 Aug 2017 23:07:11 -0700 Subject: [PATCH 1/2] add fashion mnist and move mnists to s3 --- python/mxnet/gluon/data/vision.py | 56 ++++++++++++++++++++++-- tests/python/unittest/test_gluon_data.py | 1 + 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision.py index b63624508124..85a1f9064b3d 100644 --- a/python/mxnet/gluon/data/vision.py +++ b/python/mxnet/gluon/data/vision.py @@ -70,14 +70,14 @@ class MNIST(_DownloadedDataset): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root='~/.mxnet/datasets/', train=True, + def __init__(self, root='~/.mxnet/datasets/mnist', train=True, transform=None): super(MNIST, self).__init__(root, train, transform) def _get_data(self): if not os.path.isdir(self._root): os.makedirs(self._root) - url = 'http://data.mxnet.io/data/mnist/' + url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/mnist/' if self._train: data_file = download(url+'train-images-idx3-ubyte.gz', self._root, sha1_hash='6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d') @@ -102,6 +102,56 @@ def _get_data(self): self._label = label +class FashionMNIST(_DownloadedDataset): + """A dataset of Zalando's article images consisting of fashion products, + a drop-in replacement of the original MNIST dataset from + `https://github.com/zalandoresearch/fashion-mnist`_. + + Each sample is an image (in 3D NDArray) with shape (28, 28, 1). + + Parameters + ---------- + root : str + Path to temp folder for storing data. + train : bool + Whether to load the training or testing set. + transform : function + A user defined callback that transforms each instance. For example:: + + transform=lambda data, label: (data.astype(np.float32)/255, label) + """ + def __init__(self, root='~/.mxnet/datasets/fashion-mnist', train=True, + transform=None): + super(FashionMNIST, self).__init__(root, train, transform) + + def _get_data(self): + if not os.path.isdir(self._root): + os.makedirs(self._root) + url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/fashion-mnist/' + if self._train: + data_file = download(url+'train-images-idx3-ubyte.gz', self._root, + sha1_hash='0cf37b0d40ed5169c6b3aba31069a9770ac9043d') + label_file = download(url+'train-labels-idx1-ubyte.gz', self._root, + sha1_hash='236021d52f1e40852b06a4c3008d8de8aef1e40b') + else: + data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root, + sha1_hash='626ed6a7c06dd17c0eec72fa3be1740f146a2863') + label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root, + sha1_hash='17f9ab60e7257a1620f4ad76bbbaf857c3920701') + + with gzip.open(label_file, 'rb') as fin: + struct.unpack(">II", fin.read(8)) + label = np.fromstring(fin.read(), dtype=np.uint8).astype(np.int32) + + with gzip.open(data_file, 'rb') as fin: + struct.unpack(">IIII", fin.read(16)) + data = np.fromstring(fin.read(), dtype=np.uint8) + data = data.reshape(len(label), 28, 28, 1) + + self._data = [nd.array(x, dtype=x.dtype) for x in data] + self._label = label + + class CIFAR10(_DownloadedDataset): """CIFAR10 image classification dataset from `https://www.cs.toronto.edu/~kriz/cifar.html`_. @@ -118,7 +168,7 @@ class CIFAR10(_DownloadedDataset): transform=lambda data, label: (data.astype(np.float32)/255, label) """ - def __init__(self, root='~/.mxnet/datasets/', train=True, + def __init__(self, root='~/.mxnet/datasets/cifar10', train=True, transform=None): self._file_hashes = {'data_batch_1.bin': 'aadd24acce27caa71bf4b10992e9e7b2d74c2540', 'data_batch_2.bin': 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795', diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 32298fcd57d5..7f388be73cb3 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -71,6 +71,7 @@ def test_sampler(): def test_datasets(): assert len(gluon.data.vision.MNIST(root='data')) == 60000 + assert len(gluon.data.vision.FashionMNIST(root='data')) == 60000 assert len(gluon.data.vision.CIFAR10(root='data', train=False)) == 10000 def test_image_folder_dataset(): From 8ea14857c564a0563dbb596f44951396905ab190 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 27 Aug 2017 15:15:26 -0700 Subject: [PATCH 2/2] refactor --- python/mxnet/gluon/data/vision.py | 70 +++++++++++++------------------ 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision.py index 85a1f9064b3d..24c060c54c84 100644 --- a/python/mxnet/gluon/data/vision.py +++ b/python/mxnet/gluon/data/vision.py @@ -40,6 +40,8 @@ def __init__(self, root, train, transform): self._data = None self._label = None + if not os.path.isdir(self._root): + os.makedirs(self._root) self._get_data() def __getitem__(self, idx): @@ -72,22 +74,27 @@ class MNIST(_DownloadedDataset): """ def __init__(self, root='~/.mxnet/datasets/mnist', train=True, transform=None): + self._base_url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/mnist/' + self._train_data = ('train-images-idx3-ubyte.gz', + '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d') + self._train_label = ('train-labels-idx1-ubyte.gz', + '2a80914081dc54586dbdf242f9805a6b8d2a15fc') + self._test_data = ('t10k-images-idx3-ubyte.gz', + 'c3a25af1f52dad7f726cce8cacb138654b760d48') + self._test_label = ('t10k-labels-idx1-ubyte.gz', + '763e7fa3757d93b0cdec073cef058b2004252c17') super(MNIST, self).__init__(root, train, transform) def _get_data(self): - if not os.path.isdir(self._root): - os.makedirs(self._root) - url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/mnist/' if self._train: - data_file = download(url+'train-images-idx3-ubyte.gz', self._root, - sha1_hash='6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d') - label_file = download(url+'train-labels-idx1-ubyte.gz', self._root, - sha1_hash='2a80914081dc54586dbdf242f9805a6b8d2a15fc') + data, label = self._train_data, self._train_label else: - data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root, - sha1_hash='c3a25af1f52dad7f726cce8cacb138654b760d48') - label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root, - sha1_hash='763e7fa3757d93b0cdec073cef058b2004252c17') + data, label = self._test_data, self._test_label + + data_file = download(self._base_url + data[0], self._root, + sha1_hash=data[1]) + label_file = download(self._base_url + label[0], self._root, + sha1_hash=label[1]) with gzip.open(label_file, 'rb') as fin: struct.unpack(">II", fin.read(8)) @@ -102,7 +109,7 @@ def _get_data(self): self._label = label -class FashionMNIST(_DownloadedDataset): +class FashionMNIST(MNIST): """A dataset of Zalando's article images consisting of fashion products, a drop-in replacement of the original MNIST dataset from `https://github.com/zalandoresearch/fashion-mnist`_. @@ -122,35 +129,17 @@ class FashionMNIST(_DownloadedDataset): """ def __init__(self, root='~/.mxnet/datasets/fashion-mnist', train=True, transform=None): + self._base_url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/fashion-mnist/' + self._train_data = ('train-images-idx3-ubyte.gz', + '0cf37b0d40ed5169c6b3aba31069a9770ac9043d') + self._train_label = ('train-labels-idx1-ubyte.gz', + '236021d52f1e40852b06a4c3008d8de8aef1e40b') + self._test_data = ('t10k-images-idx3-ubyte.gz', + '626ed6a7c06dd17c0eec72fa3be1740f146a2863') + self._test_label = ('t10k-labels-idx1-ubyte.gz', + '17f9ab60e7257a1620f4ad76bbbaf857c3920701') super(FashionMNIST, self).__init__(root, train, transform) - def _get_data(self): - if not os.path.isdir(self._root): - os.makedirs(self._root) - url = 'https://apache-mxnet.s3.amazonaws.com/gluon/dataset/fashion-mnist/' - if self._train: - data_file = download(url+'train-images-idx3-ubyte.gz', self._root, - sha1_hash='0cf37b0d40ed5169c6b3aba31069a9770ac9043d') - label_file = download(url+'train-labels-idx1-ubyte.gz', self._root, - sha1_hash='236021d52f1e40852b06a4c3008d8de8aef1e40b') - else: - data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root, - sha1_hash='626ed6a7c06dd17c0eec72fa3be1740f146a2863') - label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root, - sha1_hash='17f9ab60e7257a1620f4ad76bbbaf857c3920701') - - with gzip.open(label_file, 'rb') as fin: - struct.unpack(">II", fin.read(8)) - label = np.fromstring(fin.read(), dtype=np.uint8).astype(np.int32) - - with gzip.open(data_file, 'rb') as fin: - struct.unpack(">IIII", fin.read(16)) - data = np.fromstring(fin.read(), dtype=np.uint8) - data = data.reshape(len(label), 28, 28, 1) - - self._data = [nd.array(x, dtype=x.dtype) for x in data] - self._label = label - class CIFAR10(_DownloadedDataset): """CIFAR10 image classification dataset from `https://www.cs.toronto.edu/~kriz/cifar.html`_. @@ -186,9 +175,6 @@ def _read_batch(self, filename): data[:, 0].astype(np.int32) def _get_data(self): - if not os.path.isdir(self._root): - os.makedirs(self._root) - file_paths = [(name, os.path.join(self._root, 'cifar-10-batches-bin/', name)) for name in self._file_hashes] if any(not os.path.exists(path) or not check_sha1(path, self._file_hashes[name])