monai/data/thread_buffer.py: 24 changes (20 additions, 4 deletions)
@@ -33,11 +33,11 @@ class ThreadBuffer:
         timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
     """

-    def __init__(self, src, buffer_size=1, timeout=0.01):
+    def __init__(self, src, buffer_size: int = 1, timeout: float = 0.01):
         self.src = src
         self.buffer_size = buffer_size
         self.timeout = timeout
-        self.buffer = Queue(self.buffer_size)
+        self.buffer: Queue = Queue(self.buffer_size)
         self.gen_thread = None
         self.is_running = False

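For context on the hunk above: `ThreadBuffer` wraps any iterable and yields its items from a background generator thread, using an internal `Queue` bounded by `buffer_size`. A minimal usage sketch, assuming `ThreadBuffer` is exported from `monai.data` as in current MONAI; the source generator and its items are made up for illustration:

```python
from monai.data import ThreadBuffer

# Hypothetical source iterable standing in for a real data pipeline.
def slow_batches():
    for i in range(10):
        yield {"image": i}  # placeholder for a real batch

# Items are produced on a separate thread and held in a queue of size `buffer_size`;
# `timeout` bounds how long the producer/consumer waits on the queue.
for batch in ThreadBuffer(src=slow_batches(), buffer_size=1, timeout=0.01):
    print(batch)
```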
@@ -82,11 +82,27 @@ class ThreadDataLoader(DataLoader):
     Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will
     iterate over data from the loader as expected however the data is generated on a separate thread. Use this class
     where a `DataLoader` instance is required and not just an iterable object.
+
+    Args:
+        dataset: input dataset.
+        buffer_size: number of items to buffer from the data source.
+        buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
+        num_workers: number of the multi-processing workers in PyTorch DataLoader.
+
     """

-    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
+    def __init__(
+        self,
+        dataset: Dataset,
+        buffer_size: int = 1,
+        buffer_timeout: float = 0.01,
+        num_workers: int = 0,
+        **kwargs,
+    ):
         super().__init__(dataset, num_workers, **kwargs)
+        self.buffer_size = buffer_size
+        self.buffer_timeout = buffer_timeout

     def __iter__(self):
-        buffer = ThreadBuffer(super().__iter__())
+        buffer = ThreadBuffer(src=super().__iter__(), buffer_size=self.buffer_size, timeout=self.buffer_timeout)
         yield from buffer
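A short usage sketch of the keyword arguments this diff adds: `buffer_size` and `buffer_timeout` are stored on the loader and forwarded to the internal `ThreadBuffer` each time `__iter__` is called, so batches are prefetched on a separate thread while the main thread consumes them. The dataset contents and parameter values below are illustrative only; `Dataset` and `ThreadDataLoader` are assumed importable from `monai.data`:

```python
import torch
from monai.data import Dataset, ThreadDataLoader

# Placeholder data; a real pipeline would supply image/label dicts plus transforms.
data = [{"image": torch.rand(1, 32, 32)} for _ in range(8)]
ds = Dataset(data=data)

# Extra kwargs such as batch_size pass through to the underlying DataLoader as before.
loader = ThreadDataLoader(ds, buffer_size=2, buffer_timeout=0.02, num_workers=0, batch_size=2)

for batch in loader:
    print(batch["image"].shape)  # batches arrive from the background buffer thread
```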