diff --git a/deepspeed/pt/deepspeed_dataloader.py b/deepspeed/pt/deepspeed_dataloader.py index d548cd6dbf2f..d4a43dd60f87 100644 --- a/deepspeed/pt/deepspeed_dataloader.py +++ b/deepspeed/pt/deepspeed_dataloader.py @@ -2,6 +2,7 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' +import math import torch from torch.utils.data import DataLoader, RandomSampler from torch.utils.data.distributed import DistributedSampler @@ -17,6 +18,7 @@ def __init__(self, collate_fn=None, num_local_io_workers=None, data_sampler=None): + self.tput_timer = tput_timer self.batch_size = batch_size @@ -40,35 +42,21 @@ def __init__(self, self.device_count = device_count self.batch_size = batch_size self.pin_memory = pin_memory - self.len = len(self.data_sampler) - self.data = None + self.dataloader = DataLoader(self.dataset, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + sampler=self.data_sampler, + collate_fn=None if self.collate_fn is None else self.collate_fn, + num_workers=self.num_local_io_workers) + self._iter_data_loader = iter(self.dataloader) def __iter__(self): - self._create_dataloader() return self def __len__(self): - return self.len + return len(self.dataloader) def __next__(self): if self.tput_timer: self.tput_timer.start() - return next(self.data) - - def _create_dataloader(self): - if self.collate_fn is None: - self.dataloader = DataLoader(self.dataset, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - sampler=self.data_sampler, - num_workers=self.num_local_io_workers) - else: - self.dataloader = DataLoader(self.dataset, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - sampler=self.data_sampler, - collate_fn=self.collate_fn, - num_workers=self.num_local_io_workers) - self.data = (x for x in self.dataloader) - - return self.dataloader + return next(self._iter_data_loader)