From 54e3c8b2b77904656efa958cf8ecc19a3e25e312 Mon Sep 17 00:00:00 2001
From: Soohwan Kim
Date: Sun, 29 Aug 2021 20:01:27 +0900
Subject: [PATCH] Add uniform-length batching (smart batching)

[resolved #82] - Soohwan Kim
---
 openspeech/data/sampler.py                          | 49 +++++++++++++++++--
 openspeech/dataclass/configurations.py              |  4 ++
 .../datasets/aishell/lit_data_module.py             | 19 +++----
 .../datasets/ksponspeech/lit_data_module.py         | 19 +++----
 .../language_model/lit_data_module.py               |  8 +--
 .../datasets/librispeech/lit_data_module.py         | 21 ++++----
 6 files changed, 84 insertions(+), 36 deletions(-)

diff --git a/openspeech/data/sampler.py b/openspeech/data/sampler.py
index e1d1e80..3361ee5 100644
--- a/openspeech/data/sampler.py
+++ b/openspeech/data/sampler.py
@@ -20,14 +20,16 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import os
 import numpy as np
-
 from torch.utils.data import Sampler
 
+from .audio.load import load_audio
+
-class BucketingSampler(Sampler):
+class RandomSampler(Sampler):
     r"""
-    Samples batches assuming they are in order of size to batch similarly sized samples together.
+    Implementation of a random sampler for sampling the dataset.
 
     Args:
         data_source (torch.utils.data.Dataset): dataset to sample from
         batch_size (int): size of batch
@@ -35,7 +37,7 @@ class BucketingSampler(Sampler):
         drop_last (bool): flag indicating whether to drop the last batch or not
     """
     def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
-        super(BucketingSampler, self).__init__(data_source)
+        super(RandomSampler, self).__init__(data_source)
         self.batch_size = batch_size
         self.data_source = data_source
         ids = list(range(0, len(data_source)))
@@ -52,3 +54,42 @@ def __len__(self):
 
     def shuffle(self, epoch):
         np.random.shuffle(self.bins)
+
+
+class SmartBatchingSampler(Sampler):
+    r"""
+    Batches samples of similar sequence length together to reduce padding within each batch (smart batching).
+
+    Args:
+        data_source (torch.utils.data.Dataset): dataset to sample from
+        batch_size (int): size of batch
+        drop_last (bool): flag indicating whether to drop the last batch or not
+    """
+    def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
+        super(SmartBatchingSampler, self).__init__(data_source)
+        self.batch_size = batch_size
+        self.data_source = data_source
+
+        audio_lengths = [self._get_audio_length(audio_path) for audio_path in data_source.audio_paths]
+        audio_indices = [idx for idx in range(len(data_source.audio_paths))]
+
+        pack_by_length = list(zip(audio_lengths, audio_indices))
+        sort_by_length = sorted(pack_by_length)
+        audio_lengths, audio_indices = zip(*sort_by_length)
+
+        self.bins = [list(audio_indices[i:i + batch_size]) for i in range(0, len(audio_indices), batch_size)]
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        for ids in self.bins:
+            np.random.shuffle(ids)
+            yield ids
+
+    def _get_audio_length(self, audio_path):
+        return len(load_audio(os.path.join(self.data_source.dataset_path, audio_path)))
+
+    def __len__(self):
+        return len(self.bins)
+
+    def shuffle(self, epoch):
+        np.random.shuffle(self.bins)
diff --git a/openspeech/dataclass/configurations.py b/openspeech/dataclass/configurations.py
index 0719f94..7d9855a 100644
--- a/openspeech/dataclass/configurations.py
+++ b/openspeech/dataclass/configurations.py
@@ -209,6 +209,10 @@ class BaseTrainerConfigs(OpenspeechDataclass):
         default="binsearch", metadata={"help": "If set to True, will initially run a batch size finder trying to find "
                                                "the largest batch size that fits into memory."}
     )
+    sampler: str = field(
+        default="smart", metadata={"help": "smart: batching with similar sequence length. "
+                                           "Otherwise: random batching."}
+    )
 
 
 @dataclass
diff --git a/openspeech/datasets/aishell/lit_data_module.py b/openspeech/datasets/aishell/lit_data_module.py
index 8db4462..7951753 100644
--- a/openspeech/datasets/aishell/lit_data_module.py
+++ b/openspeech/datasets/aishell/lit_data_module.py
@@ -27,14 +27,12 @@
 import logging
 from omegaconf import DictConfig
 from typing import Optional, Tuple
-from torch.utils.data import DataLoader
 
 from openspeech.data.audio.dataset import SpeechToTextDataset
 from openspeech.datasets import register_data_module
-from openspeech.data.sampler import BucketingSampler
+from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
 from openspeech.data.audio.data_loader import AudioDataLoader
 from openspeech.tokenizers.tokenizer import Tokenizer
-from openspeech.tokenizers import TOKENIZER_REGISTRY
 from openspeech.datasets.aishell.preprocess import (
     generate_character_labels,
     generate_character_script,
@@ -158,24 +156,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
             del_silence=self.configs.audio.del_silence if stage == 'train' else False,
         )
 
-    def train_dataloader(self) -> DataLoader:
-        train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
+    def train_dataloader(self) -> AudioDataLoader:
+        sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
+        train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
         return AudioDataLoader(
             dataset=self.dataset['train'],
             num_workers=self.configs.trainer.num_workers,
             batch_sampler=train_sampler,
         )
 
-    def val_dataloader(self) -> DataLoader:
-        valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
+    def val_dataloader(self) -> AudioDataLoader:
+        sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
+        valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
         return AudioDataLoader(
             dataset=self.dataset['valid'],
             num_workers=self.configs.trainer.num_workers,
             batch_sampler=valid_sampler,
         )
 
-    def test_dataloader(self) -> DataLoader:
-        test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
+    def test_dataloader(self) -> AudioDataLoader:
+        sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
+        test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
         return AudioDataLoader(
             dataset=self.dataset['test'],
             num_workers=self.configs.trainer.num_workers,
diff --git a/openspeech/datasets/ksponspeech/lit_data_module.py b/openspeech/datasets/ksponspeech/lit_data_module.py
index 9112f52..c69c54f 100644
--- a/openspeech/datasets/ksponspeech/lit_data_module.py
+++ b/openspeech/datasets/ksponspeech/lit_data_module.py
@@ -25,16 +25,15 @@
 import pytorch_lightning as pl
 from typing import Optional
 from omegaconf import DictConfig
-from openspeech.data.audio.dataset import SpeechToTextDataset
 
+from openspeech.data.audio.dataset import SpeechToTextDataset
 from openspeech.datasets import register_data_module
-from openspeech.data.sampler import BucketingSampler
+from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
 from openspeech.data.audio.data_loader import AudioDataLoader
 from openspeech.datasets.ksponspeech.preprocess.preprocess import preprocess, preprocess_test_data
 from openspeech.datasets.ksponspeech.preprocess.character import generate_character_script, generate_character_labels
 from openspeech.datasets.ksponspeech.preprocess.grapheme import sentence_to_grapheme
 from openspeech.datasets.ksponspeech.preprocess.subword import train_sentencepiece, sentence_to_subwords
-from openspeech.tokenizers import TOKENIZER_REGISTRY
 from openspeech.tokenizers.tokenizer import Tokenizer
 
 
@@ -49,6 +48,8 @@ class LightningKsponSpeechDataModule(pl.LightningDataModule):
 
     Attributes:
         KSPONSPEECH_TRAIN_NUM (int): the number of KsponSpeech's train data.
+        KSPONSPEECH_VALID_NUM (int): the number of KsponSpeech's validation data.
+        KSPONSPEECH_TEST_NUM (int): the number of KsponSpeech's test data.
 
     Args:
         configs (DictConfig): configuration set.
@@ -173,8 +174,8 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
         )
 
     def train_dataloader(self) -> AudioDataLoader:
-        r""" Return data loader for training. """
-        train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
+        sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
+        train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
         return AudioDataLoader(
             dataset=self.dataset['train'],
             num_workers=self.configs.trainer.num_workers,
@@ -182,8 +183,8 @@ def train_dataloader(self) -> AudioDataLoader:
         )
 
     def val_dataloader(self) -> AudioDataLoader:
-        r""" Return data loader for validation. """
""" - valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) + sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler + valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) return AudioDataLoader( dataset=self.dataset['valid'], num_workers=self.configs.trainer.num_workers, @@ -191,8 +192,8 @@ def val_dataloader(self) -> AudioDataLoader: ) def test_dataloader(self) -> AudioDataLoader: - r""" Return data loader for training. """ - test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) + sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler + test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) return AudioDataLoader( dataset=self.dataset['test'], num_workers=self.configs.trainer.num_workers, diff --git a/openspeech/datasets/language_model/lit_data_module.py b/openspeech/datasets/language_model/lit_data_module.py index 3d3c847..e1413ad 100644 --- a/openspeech/datasets/language_model/lit_data_module.py +++ b/openspeech/datasets/language_model/lit_data_module.py @@ -27,7 +27,7 @@ from omegaconf import DictConfig from typing import Optional -from openspeech.data.sampler import BucketingSampler +from openspeech.data.sampler import RandomSampler from openspeech.data.text.data_loader import TextDataLoader from openspeech.data.text.dataset import TextDataset from openspeech.datasets import register_data_module @@ -78,7 +78,7 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None): ) def train_dataloader(self) -> TextDataLoader: - train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size) + train_sampler = RandomSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size) return TextDataLoader( dataset=self.dataset['train'], num_workers=self.configs.trainer.num_workers, @@ -87,7 +87,7 @@ def train_dataloader(self) -> TextDataLoader: def val_dataloader(self) -> TextDataLoader: r""" Return data loader for validation. """ - valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) + valid_sampler = RandomSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) return TextDataLoader( dataset=self.dataset['valid'], num_workers=self.configs.trainer.num_workers, @@ -96,7 +96,7 @@ def val_dataloader(self) -> TextDataLoader: def test_dataloader(self) -> TextDataLoader: r""" Return data loader for training. 
""" - train_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) + train_sampler = RandomSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) return TextDataLoader( dataset=self.dataset['test'], num_workers=self.configs.trainer.num_workers, diff --git a/openspeech/datasets/librispeech/lit_data_module.py b/openspeech/datasets/librispeech/lit_data_module.py index 384a9e1..48f8914 100644 --- a/openspeech/datasets/librispeech/lit_data_module.py +++ b/openspeech/datasets/librispeech/lit_data_module.py @@ -28,13 +28,11 @@ import pytorch_lightning as pl from typing import Tuple, Optional from omegaconf import DictConfig -from openspeech.data.audio.dataset import SpeechToTextDataset -from torch.utils.data import DataLoader +from openspeech.data.audio.dataset import SpeechToTextDataset from openspeech.datasets import register_data_module -from openspeech.tokenizers import TOKENIZER_REGISTRY from openspeech.tokenizers.tokenizer import Tokenizer -from openspeech.data.sampler import BucketingSampler +from openspeech.data.sampler import RandomSampler, SmartBatchingSampler from openspeech.data.audio.data_loader import AudioDataLoader @@ -188,24 +186,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None) -> Non del_silence=self.configs.audio.del_silence if stage == 'train' else False, ) - def train_dataloader(self) -> DataLoader: - train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size) + def train_dataloader(self) -> AudioDataLoader: + sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler + train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size) return AudioDataLoader( dataset=self.dataset['train'], num_workers=self.configs.trainer.num_workers, batch_sampler=train_sampler, ) - def val_dataloader(self) -> DataLoader: - valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) + def val_dataloader(self) -> AudioDataLoader: + sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler + valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size) return AudioDataLoader( dataset=self.dataset['valid'], num_workers=self.configs.trainer.num_workers, batch_sampler=valid_sampler, ) - def test_dataloader(self) -> DataLoader: - test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) + def test_dataloader(self) -> AudioDataLoader: + sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler + test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size) return AudioDataLoader( dataset=self.dataset['test'], num_workers=self.configs.trainer.num_workers,