Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions deepspeed/pt/deepspeed_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Copyright 2019 The Microsoft DeepSpeed Team
'''

import math
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
Expand All @@ -17,6 +18,7 @@ def __init__(self,
collate_fn=None,
num_local_io_workers=None,
data_sampler=None):

self.tput_timer = tput_timer
self.batch_size = batch_size

Expand All @@ -40,35 +42,21 @@ def __init__(self,
self.device_count = device_count
self.batch_size = batch_size
self.pin_memory = pin_memory
self.len = len(self.data_sampler)
self.data = None
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
collate_fn=None if self.collate_fn is None else self.collate_fn,
num_workers=self.num_local_io_workers)
self._iter_data_loader = iter(self.dataloader)

def __iter__(self):
self._create_dataloader()
return self

def __len__(self):
return self.len
return len(self.dataloader)

def __next__(self):
if self.tput_timer:
self.tput_timer.start()
return next(self.data)

def _create_dataloader(self):
if self.collate_fn is None:
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
num_workers=self.num_local_io_workers)
else:
self.dataloader = DataLoader(self.dataset,
batch_size=self.batch_size,
pin_memory=self.pin_memory,
sampler=self.data_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_local_io_workers)
self.data = (x for x in self.dataloader)

return self.dataloader
return next(self._iter_data_loader)