From 6e625c2327809e63cfeb30c221efb0d98b0b82e0 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Thu, 8 Nov 2018 18:27:14 -0800 Subject: [PATCH 01/12] initial --- python/mxnet/gluon/data/dataloader.py | 30 ++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 86cb835f5128..dc7b820c602a 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -252,6 +252,35 @@ def __len__(self): def __del__(self): self.shutdown() + def reset(self): + """Reset iterator""" + # clear key queue + removed_idx = [] + while True: + try: + idx, _ = self._key_queue.get(False) + removed_idx.append(idx) + except Queue.Empty: + break + + # clear data queue + while self._rcvd_idx < self._sent_idx: + if self._rcvd_idx in removed_idx: + self._rcvd_idx += 1 + elif self._rcvd_idx in self._data_buffer: + _ = self._data_buffer.pop(self._rcvd_idx) + self._rcvd_idx += 1 + assert not self._data_buffer, "data buffer should be empty" + + # reset indices and samples + self._rcvd_idx = 0 + self._sent_idx = 0 + self._iter = iter(self._batch_sampler) + + # pre-fetch + for _ in range(2 * self._num_workers): + self._push_next() + def _push_next(self): """Assign next batch workload to workers.""" r = next(self._iter, None) @@ -264,7 +293,6 @@ def __next__(self): assert not self._shutdown, "call __next__ after shutdown is forbidden" if self._rcvd_idx == self._sent_idx: assert not self._data_buffer, "Data buffer should be empty at this moment" - self.shutdown() raise StopIteration while True: From 854a797672b794a44c4c5952df8d4ba17c9a3a6d Mon Sep 17 00:00:00 2001 From: "Joshua Z. 
Zhang" Date: Fri, 9 Nov 2018 11:47:58 -0800 Subject: [PATCH 02/12] wip --- python/mxnet/gluon/data/dataloader.py | 39 ++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index dc7b820c602a..a210889222d0 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -29,6 +29,11 @@ import threading import numpy as np +try: + from Queue import Empty as QueueEmpty +except ImportError: + from queue import Empty as QueueEmpty + try: import multiprocessing.resource_sharer except ImportError: @@ -260,7 +265,7 @@ def reset(self): try: idx, _ = self._key_queue.get(False) removed_idx.append(idx) - except Queue.Empty: + except QueueEmpty: break # clear data queue @@ -327,6 +332,38 @@ def shutdown(self): self._shutdown = True +class _SameProcessIter(object): + def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False): + self._dataset = dataset + self._batchify_fn = batchify_fn + self._batch_sampler = batch_sampler + self._pin_memory = pin_memory + self._idx = 0 + + def __len__(self): + return len(self._batch_sampler) + + def reset(self): + """Reset iterator""" + self._idx = 0 + + def __next__(self): + if self._idx == self.__len__(): + raise StopIteration + + batch = self._batch_sampler[self._idx] + ret = self._batchify_fn([self._dataset[idx] for idx in batch]) + if self._pin_memory: + ret = _as_in_context(ret, context.cpu_pinned()) + self._idx += 1 + return ret + + def next(self): + return self.__next__() + + def __iter__(self): + return self + class DataLoader(object): """Loads data from a dataset and returns mini-batches of data. From 24386ff573d9e7b8f9927f9a0b34d33abca8f459 Mon Sep 17 00:00:00 2001 From: "Joshua Z. 
Zhang" Date: Fri, 9 Nov 2018 14:49:03 -0800 Subject: [PATCH 03/12] add unittest --- python/mxnet/gluon/data/dataloader.py | 76 +++++++++++++++++------- tests/python/unittest/test_gluon_data.py | 22 +++++++ 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index a210889222d0..6c15330b63c9 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -212,7 +212,28 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non class _MultiWorkerIter(object): - """Interal multi-worker iterator for DataLoader.""" + """Interal multi-worker iterator for DataLoader. + It allow reset() to reuse iterator with all workers alive. + + Parameters + ---------- + num_workers : int, default 0 + The number of multiprocessing workers to use for data preprocessing. + dataset : Dataset + Source dataset. Note that numpy and mxnet arrays can be directly used + as a Dataset. + batchify_fn : callable + Callback function to allow users to specify how to merge samples + into a batch. + batch_sampler : Sampler + A sampler that returns mini-batches. Do not specify batch_size, + shuffle, sampler, and last_batch if batch_sampler is specified. + pin_memory : boolean, default False + If ``True``, the dataloader will copy NDArrays into pinned memory + before returning them. Copying from CPU pinned memory to GPU is faster + than from normal CPU memory. 
+ + """ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, worker_fn=worker_loop): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) @@ -258,7 +279,7 @@ def __del__(self): self.shutdown() def reset(self): - """Reset iterator""" + """Reset iterator with multiprocessing workers alive.""" # clear key queue removed_idx = [] while True: @@ -333,30 +354,49 @@ def shutdown(self): class _SameProcessIter(object): + """Same Process Iterator, which allow reset(). + + Parameters + ---------- + dataset : Dataset + Source dataset. Note that numpy and mxnet arrays can be directly used + as a Dataset. + batchify_fn : callable + Callback function to allow users to specify how to merge samples + into a batch. + batch_sampler : Sampler + A sampler that returns mini-batches. Do not specify batch_size, + shuffle, sampler, and last_batch if batch_sampler is specified. + pin_memory : boolean, default False + If ``True``, the dataloader will copy NDArrays into pinned memory + before returning them. Copying from CPU pinned memory to GPU is faster + than from normal CPU memory. 
+ + """ def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False): self._dataset = dataset self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler + self._iter = iter(self._batch_sampler) self._pin_memory = pin_memory - self._idx = 0 def __len__(self): return len(self._batch_sampler) def reset(self): - """Reset iterator""" - self._idx = 0 + """Reset iterator.""" + self._iter = iter(self._batch_sampler) def __next__(self): - if self._idx == self.__len__(): + try: + batch = next(self._iter) + except StopIteration: raise StopIteration - - batch = self._batch_sampler[self._idx] - ret = self._batchify_fn([self._dataset[idx] for idx in batch]) - if self._pin_memory: - ret = _as_in_context(ret, context.cpu_pinned()) - self._idx += 1 - return ret + else: + ret = self._batchify_fn([self._dataset[idx] for idx in batch]) + if self._pin_memory: + ret = _as_in_context(ret, context.cpu_pinned()) + return ret def next(self): return self.__next__() @@ -364,6 +404,7 @@ def next(self): def __iter__(self): return self + class DataLoader(object): """Loads data from a dataset and returns mini-batches of data. 
@@ -446,13 +487,8 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, def __iter__(self): if self._num_workers == 0: - def same_process_iter(): - for batch in self._batch_sampler: - ret = self._batchify_fn([self._dataset[idx] for idx in batch]) - if self._pin_memory: - ret = _as_in_context(ret, context.cpu_pinned()) - yield ret - return same_process_iter() + return _SameProcessIter(self._dataset, self._batchify_fn, + self._batch_sampler, self._pin_memory) # multi-worker return _MultiWorkerIter(self._num_workers, self._dataset, diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index e4206095f9ba..cf2c0c497f82 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -244,6 +244,28 @@ def test_multi_worker_forked_data_loader(): for i, data in enumerate(loader): pass +@with_seed() +def test_cached_iterator_in_dataloader(): + class _DummyData(object): + def __len__(self): + return 100 + + def __getitem__(self, idx): + return idx + + data = _DummyData() + length = len(data) + expect = np.arange(length) + for num_worker in range(2, 4): + loader = DataLoader(data, batch_size=2, shuffle=False, num_workers=num_worker) + it = iter(loader) + it.reset() + out = [] + for i, batch in enumerate(it): + print(i, batch) + out.append(batch.asnumpy().flatten()) + np.testing.assert_allclose(np.concatenate(out), expect) + if __name__ == '__main__': import nose nose.runmodule() From 922fd77ca22cc8b042653422b724fb9d49ea5acf Mon Sep 17 00:00:00 2001 From: "Joshua Z. 
Zhang" Date: Fri, 9 Nov 2018 15:17:05 -0800 Subject: [PATCH 04/12] update test --- python/mxnet/gluon/data/dataloader.py | 4 ++-- tests/python/unittest/test_gluon_data.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 6c15330b63c9..aefc54ec98fc 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -281,11 +281,11 @@ def __del__(self): def reset(self): """Reset iterator with multiprocessing workers alive.""" # clear key queue - removed_idx = [] + removed_idx = set() while True: try: idx, _ = self._key_queue.get(False) - removed_idx.append(idx) + removed_idx.add(idx) except QueueEmpty: break diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index cf2c0c497f82..9d0db6547dfa 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -256,7 +256,7 @@ def __getitem__(self, idx): data = _DummyData() length = len(data) expect = np.arange(length) - for num_worker in range(2, 4): + for num_worker in range(0, 4): loader = DataLoader(data, batch_size=2, shuffle=False, num_workers=num_worker) it = iter(loader) it.reset() From 02ae9efe259b1ec3831174e25c0d46e418b00bc1 Mon Sep 17 00:00:00 2001 From: "Joshua Z. 
Zhang" Date: Mon, 12 Nov 2018 11:24:00 -0800 Subject: [PATCH 05/12] use __iter__ --- python/mxnet/gluon/data/dataloader.py | 8 ++++---- tests/python/unittest/test_gluon_data.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index aefc54ec98fc..8fc7087ae014 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -268,9 +268,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= self._fetcher.daemon = True self._fetcher.start() - # pre-fetch - for _ in range(2 * self._num_workers): - self._push_next() + self._reset() def __len__(self): return len(self._batch_sampler) @@ -278,8 +276,9 @@ def __len__(self): def __del__(self): self.shutdown() - def reset(self): + def _reset(self): """Reset iterator with multiprocessing workers alive.""" + assert not self._shutdown, "call reset after shutdown is forbidden" # clear key queue removed_idx = set() while True: @@ -333,6 +332,7 @@ def next(self): return self.__next__() def __iter__(self): + self._reset() return self def shutdown(self): diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 9d0db6547dfa..77f733669e99 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -244,22 +244,22 @@ def test_multi_worker_forked_data_loader(): for i, data in enumerate(loader): pass -@with_seed() -def test_cached_iterator_in_dataloader(): - class _DummyData(object): - def __len__(self): - return 100 - def __getitem__(self, idx): - return idx +class _SequentialDummyData(object): + def __len__(self): + return 100 + + def __getitem__(self, idx): + return idx - data = _DummyData() +@with_seed() +def test_cached_iterator_in_dataloader(): + data = _SequentialDummyData() length = len(data) expect = np.arange(length) for num_worker in range(0, 4): loader = 
DataLoader(data, batch_size=2, shuffle=False, num_workers=num_worker) it = iter(loader) - it.reset() out = [] for i, batch in enumerate(it): print(i, batch) From 937b6fbd4680506ecc415cf2f7c93eb2de741ac4 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Wed, 14 Nov 2018 11:17:03 -0800 Subject: [PATCH 06/12] fix _SameProcessIter --- python/mxnet/gluon/data/dataloader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 8fc7087ae014..692c57a44220 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -377,13 +377,13 @@ def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False): self._dataset = dataset self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler - self._iter = iter(self._batch_sampler) self._pin_memory = pin_memory + self._reset() def __len__(self): return len(self._batch_sampler) - def reset(self): + def _reset(self): """Reset iterator.""" self._iter = iter(self._batch_sampler) @@ -402,6 +402,7 @@ def next(self): return self.__next__() def __iter__(self): + self._reset() return self From 1690b0d62df0bdf4702b139a14136822a285a0c9 Mon Sep 17 00:00:00 2001 From: Joshua Zhang Date: Sun, 25 Nov 2018 14:31:52 -0800 Subject: [PATCH 07/12] fix recordio.py --- python/mxnet/gluon/data/dataloader.py | 6 +++--- python/mxnet/recordio.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 692c57a44220..33877dc1ac77 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -183,9 +183,9 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn): # for a dataset with transform function, the depth of MXRecordIO is 1 # for a lazy transformer, the depth is 2 # for a user defined transformer, the depth is unknown, try a reasonable depth - limit = 
sys.getrecursionlimit() - max_recursion_depth = min(limit - 5, max(10, limit // 2)) - _recursive_fork_recordio(dataset, 0, max_recursion_depth) + # limit = sys.getrecursionlimit() + # max_recursion_depth = min(limit - 5, max(10, limit // 2)) + # _recursive_fork_recordio(dataset, 0, max_recursion_depth) while True: idx, samples = key_queue.get() diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 2def141c9340..2397820677e9 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -18,6 +18,7 @@ """Read and write for the RecordIO data format.""" from __future__ import absolute_import from collections import namedtuple +from multiprocessing import current_process import ctypes import struct @@ -65,6 +66,7 @@ def __init__(self, uri, flag): self.uri = c_str(uri) self.handle = RecordIOHandle() self.flag = flag + self.pid = None self.is_open = False self.open() @@ -78,6 +80,7 @@ def open(self): self.writable = False else: raise ValueError("Invalid flag %s"%self.flag) + self.pid = current_process().pid self.is_open = True def __del__(self): @@ -118,6 +121,7 @@ def close(self): else: check_call(_LIB.MXRecordIOReaderFree(self.handle)) self.is_open = False + self.pid = None def reset(self): """Resets the pointer to first item. @@ -156,6 +160,8 @@ def write(self, buf): Buffer to write. """ assert self.writable + assert self.pid == current_process().pid, \ + "writing in different process is forbidden" check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle, ctypes.c_char_p(buf), ctypes.c_size_t(len(buf)))) @@ -182,6 +188,10 @@ def read(self): Buffer read. 
""" assert not self.writable + if not self.pid == current_process().pid: + # in forked process, obtain a new handle + # print("PID not matching, reset") + self.reset() buf = ctypes.c_char_p() size = ctypes.c_size_t() check_call(_LIB.MXRecordIOReaderReadRecord(self.handle, From 871b427c1f2f989487cee4a52c897e673591efcb Mon Sep 17 00:00:00 2001 From: Joshua Zhang Date: Sun, 25 Nov 2018 15:31:38 -0800 Subject: [PATCH 08/12] fix forking behavior --- python/mxnet/gluon/data/dataloader.py | 21 --------------------- python/mxnet/recordio.py | 19 +++++++++++++------ 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 33877dc1ac77..88a3059bcf96 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -41,7 +41,6 @@ from . import sampler as _sampler from ... import nd, context -from ...recordio import MXRecordIO if sys.platform == 'darwin' or sys.platform == 'win32': def rebuild_ndarray(*args): @@ -164,29 +163,9 @@ def _as_in_context(data, ctx): return [_as_in_context(d, ctx) for d in data] return data -def _recursive_fork_recordio(obj, depth, max_depth=1000): - """Recursively find instance of MXRecordIO and reset file handler. - This is required for MXRecordIO which holds a C pointer to a opened file after fork. 
- """ - if depth >= max_depth: - return - if isinstance(obj, MXRecordIO): - obj.close() - obj.open() # re-obtain file hanlder in new process - elif (hasattr(obj, '__dict__')): - for _, v in obj.__dict__.items(): - _recursive_fork_recordio(v, depth + 1, max_depth) def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" - # re-fork a new recordio handler in new process if applicable - # for a dataset with transform function, the depth of MXRecordIO is 1 - # for a lazy transformer, the depth is 2 - # for a user defined transformer, the depth is unknown, try a reasonable depth - # limit = sys.getrecursionlimit() - # max_recursion_depth = min(limit - 5, max(10, limit // 2)) - # _recursive_fork_recordio(dataset, 0, max_recursion_depth) - while True: idx, samples = key_queue.get() if idx is None: diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 2397820677e9..bdc63235d702 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -112,6 +112,14 @@ def __setstate__(self, d): if is_open: self.open() + def _check_pid(self, allow_reset=False): + """Check process id to ensure integrity, reset if in new process.""" + if not self.pid == current_process().pid: + if allow_reset: + self.reset() + else: + raise RuntimeError("Forbidden operation in multiple processes") + def close(self): """Closes the record file.""" if not self.is_open: @@ -160,8 +168,7 @@ def write(self, buf): Buffer to write. """ assert self.writable - assert self.pid == current_process().pid, \ - "writing in different process is forbidden" + self._check_pid(allow_reset=False) check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle, ctypes.c_char_p(buf), ctypes.c_size_t(len(buf)))) @@ -188,10 +195,9 @@ def read(self): Buffer read. 
""" assert not self.writable - if not self.pid == current_process().pid: - # in forked process, obtain a new handle - # print("PID not matching, reset") - self.reset() + # trying to implicitly read from multiple processes is forbidden, + # there's no elegant way to handle unless lock is introduced + self._check_pid(allow_reset=False) buf = ctypes.c_char_p() size = ctypes.c_size_t() check_call(_LIB.MXRecordIOReaderReadRecord(self.handle, @@ -265,6 +271,7 @@ def seek(self, idx): This function is internally called by `read_idx(idx)` to find the current reader pointer position. It doesn't return anything.""" assert not self.writable + self._check_pid(allow_reset=True) pos = ctypes.c_size_t(self.idx[idx]) check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos)) From 5442fe29ab74cd69b6cd946fa23c44b05bf0ae0d Mon Sep 17 00:00:00 2001 From: Joshua Zhang Date: Sun, 25 Nov 2018 21:30:44 -0800 Subject: [PATCH 09/12] add docs to iters --- python/mxnet/gluon/data/dataloader.py | 75 +++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 88a3059bcf96..fd2a2c21b173 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -250,13 +250,21 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= self._reset() def __len__(self): + """Get length of iterator. + + Returns + ------- + int + Length of iterator, equals to batch_sampler length. + + """ return len(self._batch_sampler) def __del__(self): self.shutdown() def _reset(self): - """Reset iterator with multiprocessing workers alive.""" + """Reset iterator with multiprocessing workers alive. Internal use. 
""" assert not self._shutdown, "call reset after shutdown is forbidden" # clear key queue removed_idx = set() @@ -286,7 +294,7 @@ def _reset(self): self._push_next() def _push_next(self): - """Assign next batch workload to workers.""" + """Assign next batch workload to workers. Internal use only. """ r = next(self._iter, None) if r is None: return @@ -294,6 +302,14 @@ def _push_next(self): self._sent_idx += 1 def __next__(self): + """Return next sample, will raise `StopIteration` reaching end. + + Returns + ------- + NDArray + Batched sample data. + + """ assert not self._shutdown, "call __next__ after shutdown is forbidden" if self._rcvd_idx == self._sent_idx: assert not self._data_buffer, "Data buffer should be empty at this moment" @@ -308,14 +324,35 @@ def __next__(self): return batch def next(self): + """Compatible portal for __next__ in python2. + + Returns + ------- + type + Description of returned object. + + """ return self.__next__() def __iter__(self): + """Requiring iterator will reset current instance, but keep all workers + alive, thus save re-init time of forking processes. + + Returns + ------- + iterator + Iterator of self. + + """ self._reset() return self def shutdown(self): - """Shutdown internal workers by pushing terminate signals.""" + """ + Shutdown internal workers by pushing terminate signals. Once shutdown, + you cannot use this instance again, you will need to obtain a new + _MultiWorkerIter by `iter(dataloader)`. + """ if not self._shutdown: # send shutdown signal to the fetcher and join data queue first # Remark: loop_fetcher need to be joined prior to the workers. @@ -360,6 +397,14 @@ def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False): self._reset() def __len__(self): + """Get length of iterator. + + Returns + ------- + int + Length of iterator, equals to batch_sampler length. 
+ + """ return len(self._batch_sampler) def _reset(self): @@ -367,6 +412,14 @@ def _reset(self): self._iter = iter(self._batch_sampler) def __next__(self): + """Return next sample, will raise `StopIteration` reaching end. + + Returns + ------- + NDArray + Batched sample data. + + """ try: batch = next(self._iter) except StopIteration: @@ -378,9 +431,25 @@ def __next__(self): return ret def next(self): + """Compatible portal for __next__ in python2. + + Returns + ------- + type + Description of returned object. + + """ return self.__next__() def __iter__(self): + """Requiring iterator will reset current instance. + + Returns + ------- + iterator + Iterator of self. + + """ self._reset() return self From a30958a15518f55eeee69a167ec5a887af1f7707 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 26 Nov 2018 22:02:42 -0800 Subject: [PATCH 10/12] Modify docs to match behavior --- python/mxnet/gluon/data/dataloader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index fd2a2c21b173..2fa64009d332 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -192,7 +192,8 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader. - It allow reset() to reuse iterator with all workers alive. + Re-acquiring this iterator by `iter()` function will reset it + with all workers alive in order to save re-initialization overhead. Parameters ---------- @@ -370,7 +371,8 @@ def shutdown(self): class _SameProcessIter(object): - """Same Process Iterator, which allow reset(). + """Same Process Iterator. + Re-acquire this iterator by `iter()` function will reset it. Parameters ---------- From 6cf6ceefe92fdefa7017516a2d493db918edc1ec Mon Sep 17 00:00:00 2001 From: "Joshua Z. 
Zhang" Date: Mon, 26 Nov 2018 22:21:34 -0800 Subject: [PATCH 11/12] Fix pylint --- python/mxnet/gluon/data/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 2fa64009d332..c483b48d4352 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -192,7 +192,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader. - Re-acquiring this iterator by `iter()` function will reset it + Re-acquiring this iterator by `iter()` function will reset it with all workers alive in order to save re-initialization overhead. Parameters From 67e1a0fc68c3521cb042b7aedf85b987ba6b0739 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Tue, 27 Nov 2018 17:37:36 -0800 Subject: [PATCH 12/12] address comments --- python/mxnet/gluon/data/dataloader.py | 36 +++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index c483b48d4352..a23ba580dd34 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -191,9 +191,9 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non class _MultiWorkerIter(object): - """Interal multi-worker iterator for DataLoader. - Re-acquiring this iterator by `iter()` function will reset it - with all workers alive in order to save re-initialization overhead. + """Internal multi-worker iterator for DataLoader. + Re-acquire this iterator by `iter()` function will reset it if previous iteration is finished. + All workers are still alive in order to save re-initialization overhead. 
Parameters ---------- num_workers : int, default 0 The number of multiprocessing workers to use for data preprocessing. dataset : Dataset Source dataset. Note that numpy and mxnet arrays can be directly used as a Dataset. batchify_fn : callable Callback function to allow users to specify how to merge samples into a batch. batch_sampler : Sampler A sampler that returns mini-batches. Do not specify batch_size, shuffle, sampler, and last_batch if batch_sampler is specified. pin_memory : boolean, default False If ``True``, the dataloader will copy NDArrays into pinned memory before returning them. Copying from CPU pinned memory to GPU is faster than from normal CPU memory. + worker_fn : callable + `worker_fn` is the multiprocess worker function to process data in worker processes. + It defaults to `worker_loop(dataset, key_queue, data_queue, batchify_fn)`. + `worker_fn` takes inputs of `dataset` for input data, `key_queue` for (idx, batch_sample) + from batch sampler, `data_queue` for storing processed batch data as `NDArray`, and + `batchify_fn` for explicit batching instructions. + It is not recommended to customize `worker_fn` unless you have specific use cases. """ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, @@ -231,6 +238,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= self._sent_idx = 0 self._iter = iter(self._batch_sampler) self._shutdown = False + self._stop = False workers = [] for _ in range(self._num_workers): @@ -289,6 +297,7 @@ def _reset(self): self._rcvd_idx = 0 self._sent_idx = 0 self._iter = iter(self._batch_sampler) + self._stop = False # pre-fetch for _ in range(2 * self._num_workers): @@ -314,6 +323,7 @@ def __next__(self): assert not self._shutdown, "call __next__ after shutdown is forbidden" if self._rcvd_idx == self._sent_idx: assert not self._data_buffer, "Data buffer should be empty at this moment" + self._stop = True raise StopIteration while True: @@ -329,8 +339,8 @@ def next(self): Returns ------- - type - Description of returned object. + NDArray + Batched sample data. """ return self.__next__() @@ -345,7 +355,9 @@ def __iter__(self): Iterator of self. 
""" - self._reset() + assert not self._shutdown, "get iterator after shutdown is forbidden" + if self._stop: + self._reset() return self def shutdown(self): @@ -372,7 +384,7 @@ def shutdown(self): class _SameProcessIter(object): """Same Process Iterator. - Re-acquire this iterator by `iter()` function will reset it. + Re-acquire this iterator by `iter()` function will reset it if previous iteration is finished. Parameters ---------- @@ -396,6 +408,7 @@ def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False): self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler self._pin_memory = pin_memory + self._stop = False self._reset() def __len__(self): @@ -412,6 +425,7 @@ def __len__(self): def _reset(self): """Reset iterator.""" self._iter = iter(self._batch_sampler) + self._stop = False def __next__(self): """Return next sample, will raise `StopIteration` reaching end. @@ -425,6 +439,7 @@ def __next__(self): try: batch = next(self._iter) except StopIteration: + self._stop = True raise StopIteration else: ret = self._batchify_fn([self._dataset[idx] for idx in batch]) @@ -437,8 +452,8 @@ def next(self): Returns ------- - type - Description of returned object. + NDArray + Batched sample data. """ return self.__next__() @@ -452,7 +467,8 @@ def __iter__(self): Iterator of self. """ - self._reset() + if self._stop: + self._reset() return self