Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions deepspeed/pt/deepspeed_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ def __init__(self,
collate_fn=None,
num_local_io_workers=None,
data_sampler=None):
self.tput_timer = tput_timer
self.batch_size = batch_size

if local_rank >= 0:
if data_sampler is None:
Expand All @@ -33,42 +31,30 @@ def __init__(self,
if num_local_io_workers is None:
num_local_io_workers = 2 * device_count

self.tput_timer = tput_timer
self.batch_size = batch_size
self.num_local_io_workers = num_local_io_workers
self.data_sampler = data_sampler
self.dataset = dataset
self.collate_fn = collate_fn
self.device_count = device_count
self.batch_size = batch_size
self.pin_memory = pin_memory
self.len = len(self.data_sampler)
self.data = None
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
collate_fn=None if self.collate_fn is None else self.collate_fn,
num_workers=self.num_local_io_workers)
self._iter_data_loader = iter(self.dataloader)

def __iter__(self):
self._create_dataloader()
return self

def __len__(self):
return self.len
return len(self.dataloader)

def __next__(self):
if self.tput_timer:
self.tput_timer.start()
return next(self.data)

def _create_dataloader(self):
if self.collate_fn is None:
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
num_workers=self.num_local_io_workers)
else:
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_local_io_workers)
self.data = (x for x in self.dataloader)

return self.dataloader
return next(self._iter_data_loader)