From f79aed0ecf94738845280fcfb804702eea9de23d Mon Sep 17 00:00:00 2001 From: tomek Date: Tue, 24 Mar 2020 10:41:03 +0100 Subject: [PATCH 1/2] fix dataloader len for tqdm and others --- deepspeed/pt/deepspeed_dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/pt/deepspeed_dataloader.py b/deepspeed/pt/deepspeed_dataloader.py index d548cd6dbf2f..6f495af89ebb 100644 --- a/deepspeed/pt/deepspeed_dataloader.py +++ b/deepspeed/pt/deepspeed_dataloader.py @@ -2,6 +2,7 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' +import math import torch from torch.utils.data import DataLoader, RandomSampler from torch.utils.data.distributed import DistributedSampler @@ -40,7 +41,7 @@ def __init__(self, self.device_count = device_count self.batch_size = batch_size self.pin_memory = pin_memory - self.len = len(self.data_sampler) + self.len = int(math.ceil(len(self.data_sampler) * 1.0 / self.batch_size)) self.data = None def __iter__(self): From 155a958d2ebc91b1f52081815b811aa49240e428 Mon Sep 17 00:00:00 2001 From: tomek Date: Tue, 31 Mar 2020 13:26:03 +0200 Subject: [PATCH 2/2] fix deep speed data loader - different approach --- deepspeed/pt/deepspeed_dataloader.py | 33 +++++++++------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/deepspeed/pt/deepspeed_dataloader.py b/deepspeed/pt/deepspeed_dataloader.py index 6f495af89ebb..d4a43dd60f87 100644 --- a/deepspeed/pt/deepspeed_dataloader.py +++ b/deepspeed/pt/deepspeed_dataloader.py @@ -18,6 +18,7 @@ def __init__(self, collate_fn=None, num_local_io_workers=None, data_sampler=None): + self.tput_timer = tput_timer self.batch_size = batch_size @@ -41,35 +42,21 @@ def __init__(self, self.device_count = device_count self.batch_size = batch_size self.pin_memory = pin_memory - self.len = int(math.ceil(len(self.data_sampler) * 1.0 / self.batch_size)) - self.data = None + self.dataloader = DataLoader(self.dataset, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + 
sampler=self.data_sampler, + collate_fn=None if self.collate_fn is None else self.collate_fn, + num_workers=self.num_local_io_workers) + self._iter_data_loader = iter(self.dataloader) def __iter__(self): - self._create_dataloader() return self def __len__(self): - return self.len + return len(self.dataloader) def __next__(self): if self.tput_timer: self.tput_timer.start() - return next(self.data) - - def _create_dataloader(self): - if self.collate_fn is None: - self.dataloader = DataLoader(self.dataset, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - sampler=self.data_sampler, - num_workers=self.num_local_io_workers) - else: - self.dataloader = DataLoader(self.dataset, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - sampler=self.data_sampler, - collate_fn=self.collate_fn, - num_workers=self.num_local_io_workers) - self.data = (x for x in self.dataloader) - - return self.dataloader + return next(self._iter_data_loader)