deepspeed_dataloader: build the torch DataLoader once in __init__

Replaces the per-__iter__ _create_dataloader() helper with a single
DataLoader constructed in __init__, makes __len__ report
len(self.dataloader), and starts a fresh iterator on every __iter__ call so
that each pass over the loader sees the data again.

NOTE(review): this patch was recovered from a copy whose line breaks had
been collapsed. Indentation and blank lines were reconstructed to agree with
the hunk line counts; if exact matching fails, try
`git apply --ignore-whitespace`. The `index` line still carries the blob ids
of the originally submitted patch.

diff --git a/deepspeed/pt/deepspeed_dataloader.py b/deepspeed/pt/deepspeed_dataloader.py
index d548cd6dbf2f..083c4685f1f9 100644
--- a/deepspeed/pt/deepspeed_dataloader.py
+++ b/deepspeed/pt/deepspeed_dataloader.py
@@ -17,8 +17,6 @@ def __init__(self,
                  collate_fn=None,
                  num_local_io_workers=None,
                  data_sampler=None):
-        self.tput_timer = tput_timer
-        self.batch_size = batch_size
 
         if local_rank >= 0:
             if data_sampler is None:
@@ -33,6 +31,8 @@ def __init__(self,
         if num_local_io_workers is None:
             num_local_io_workers = 2 * device_count
 
+        self.tput_timer = tput_timer
+        self.batch_size = batch_size
         self.num_local_io_workers = num_local_io_workers
         self.data_sampler = data_sampler
         self.dataset = dataset
@@ -40,35 +40,28 @@ def __init__(self,
         self.device_count = device_count
         self.batch_size = batch_size
         self.pin_memory = pin_memory
-        self.len = len(self.data_sampler)
-        self.data = None
+        self.dataloader = DataLoader(self.dataset,
+                                     batch_size=self.batch_size,
+                                     pin_memory=self.pin_memory,
+                                     sampler=self.data_sampler,
+                                     collate_fn=self.collate_fn,
+                                     num_workers=self.num_local_io_workers)
+        # Iterator over self.dataloader; built on first use (__iter__/__next__).
+        self._iter_data_loader = None
 
     def __iter__(self):
-        self._create_dataloader()
+        # Restart on every iter() call; a single iterator shared across epochs
+        # would stay exhausted after the first full pass.
+        self._iter_data_loader = iter(self.dataloader)
         return self
 
     def __len__(self):
-        return self.len
+        return len(self.dataloader)
 
     def __next__(self):
         if self.tput_timer:
             self.tput_timer.start()
-        return next(self.data)
-
-    def _create_dataloader(self):
-        if self.collate_fn is None:
-            self.dataloader = DataLoader(self.dataset,
-                                         batch_size=self.batch_size,
-                                         pin_memory=self.pin_memory,
-                                         sampler=self.data_sampler,
-                                         num_workers=self.num_local_io_workers)
-        else:
-            self.dataloader = DataLoader(self.dataset,
-                                         batch_size=self.batch_size,
-                                         pin_memory=self.pin_memory,
-                                         sampler=self.data_sampler,
-                                         collate_fn=self.collate_fn,
-                                         num_workers=self.num_local_io_workers)
-        self.data = (x for x in self.dataloader)
-
-        return self.dataloader
+        if self._iter_data_loader is None:
+            # Allow next(loader) without a preceding iter(loader).
+            self._iter_data_loader = iter(self.dataloader)
+        return next(self._iter_data_loader)