diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py
index 2901335bd5..da5847465e 100644
--- a/monai/data/thread_buffer.py
+++ b/monai/data/thread_buffer.py
@@ -33,11 +33,11 @@ class ThreadBuffer:
         timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
     """

-    def __init__(self, src, buffer_size=1, timeout=0.01):
+    def __init__(self, src, buffer_size: int = 1, timeout: float = 0.01):
         self.src = src
         self.buffer_size = buffer_size
         self.timeout = timeout
-        self.buffer = Queue(self.buffer_size)
+        self.buffer: Queue = Queue(self.buffer_size)
         self.gen_thread = None
         self.is_running = False

@@ -82,11 +82,27 @@ class ThreadDataLoader(DataLoader):
     Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously.
     This will iterate over data from the loader as expected however the data is generated on a separate thread.
     Use this class where a `DataLoader` instance is required and not just an iterable object.
+
+    Args:
+        dataset: input dataset.
+        buffer_size: number of items to buffer from the data source.
+        buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
+        num_workers: number of multi-processing workers in the PyTorch DataLoader.
+
     """

-    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
+    def __init__(
+        self,
+        dataset: Dataset,
+        buffer_size: int = 1,
+        buffer_timeout: float = 0.01,
+        num_workers: int = 0,
+        **kwargs,
+    ):
         super().__init__(dataset, num_workers, **kwargs)
+        self.buffer_size = buffer_size
+        self.buffer_timeout = buffer_timeout

     def __iter__(self):
-        buffer = ThreadBuffer(super().__iter__())
+        buffer = ThreadBuffer(src=super().__iter__(), buffer_size=self.buffer_size, timeout=self.buffer_timeout)
         yield from buffer
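
For context, a minimal usage sketch of the updated `ThreadDataLoader` signature introduced by this patch; the toy dataset and parameter values are illustrative assumptions, not part of the change itself:

    # Illustrative sketch only: assumes a MONAI build containing this patch.
    from monai.data import Dataset, ThreadDataLoader

    # Hypothetical toy items standing in for real samples.
    data = [{"image": float(i)} for i in range(8)]
    dataset = Dataset(data=data)

    # buffer_size/buffer_timeout are forwarded to the internal ThreadBuffer,
    # so batches are prepared on a background thread while the main thread consumes them.
    loader = ThreadDataLoader(dataset, buffer_size=2, buffer_timeout=0.01, batch_size=2)

    for batch in loader:
        print(batch["image"])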