Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions openspeech/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,24 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import numpy as np

from torch.utils.data import Sampler

from .audio.load import load_audio


class BucketingSampler(Sampler):
class RandomSampler(Sampler):
r"""
Samples batches assuming they are in order of size to batch similarly sized samples together.
Implementation of a Random Sampler for sampling the dataset.

Args:
data_source (torch.utils.data.Dataset): dataset to sample from
batch_size (int): size of batch
drop_last (bool): flat indication whether to drop last batch or not
"""
def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
super(BucketingSampler, self).__init__(data_source)
super(RandomSampler, self).__init__(data_source)
self.batch_size = batch_size
self.data_source = data_source
ids = list(range(0, len(data_source)))
Expand All @@ -52,3 +54,42 @@ def __len__(self):

def shuffle(self, epoch):
np.random.shuffle(self.bins)


class SmartBatchingSampler(Sampler):
    """
    Batches samples of similar sequence length together to minimize padding.

    Audio files are sorted by decoded audio length, then chunked into
    consecutive batches, so every batch holds similarly sized samples.

    Args:
        data_source (torch.utils.data.Dataset): dataset to sample from.
            Must expose ``audio_paths`` and ``dataset_path`` attributes.
        batch_size (int): size of batch
        drop_last (bool): flag indicating whether to drop the last
            (possibly smaller than ``batch_size``) batch or not
    """
    def __init__(self, data_source, batch_size: int = 32, drop_last: bool = False) -> None:
        super(SmartBatchingSampler, self).__init__(data_source)
        self.batch_size = batch_size
        self.data_source = data_source
        self.drop_last = drop_last

        audio_lengths = [self._get_audio_length(audio_path) for audio_path in data_source.audio_paths]
        audio_indices = list(range(len(data_source.audio_paths)))

        pack_by_length = list(zip(audio_lengths, audio_indices))
        if pack_by_length:
            # Sort by (length, index); unpack back into parallel sequences.
            _, audio_indices = zip(*sorted(pack_by_length))
        else:
            # zip(*[]) would raise ValueError on an empty dataset.
            audio_indices = ()

        # Store each batch as a list: zip() yields tuples, and
        # np.random.shuffle mutates in place, raising TypeError on tuples.
        self.bins = [list(audio_indices[i:i + batch_size]) for i in range(0, len(audio_indices), batch_size)]
        if self.drop_last and self.bins and len(self.bins[-1]) < batch_size:
            # Honor drop_last: discard the trailing incomplete batch.
            self.bins = self.bins[:-1]

    def __iter__(self):
        for ids in self.bins:
            # Shuffle within the batch so sample order is not deterministic,
            # while keeping similarly sized samples grouped together.
            np.random.shuffle(ids)
            yield ids

    def _get_audio_length(self, audio_path):
        # Length (number of samples) of the decoded audio; used as sort key.
        return len(load_audio(os.path.join(self.data_source.dataset_path, audio_path)))

    def __len__(self):
        # Number of batches, not number of samples.
        return len(self.bins)

    def shuffle(self, epoch):
        # Shuffle batch order between epochs; batch contents stay length-grouped.
        np.random.shuffle(self.bins)
4 changes: 4 additions & 0 deletions openspeech/dataclass/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ class BaseTrainerConfigs(OpenspeechDataclass):
default="binsearch", metadata={"help": "If set to True, will initially run a batch size finder trying to find "
"the largest batch size that fits into memory."}
)
sampler: str = field(
default="smart", metadata={"help": "smart: batching with similar sequence length."
"else: random batch"}
)


@dataclass
Expand Down
19 changes: 10 additions & 9 deletions openspeech/datasets/aishell/lit_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@
import logging
from omegaconf import DictConfig
from typing import Optional, Tuple
from torch.utils.data import DataLoader

from openspeech.data.audio.dataset import SpeechToTextDataset
from openspeech.datasets import register_data_module
from openspeech.data.sampler import BucketingSampler
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
from openspeech.data.audio.data_loader import AudioDataLoader
from openspeech.tokenizers.tokenizer import Tokenizer
from openspeech.tokenizers import TOKENIZER_REGISTRY
from openspeech.datasets.aishell.preprocess import (
generate_character_labels,
generate_character_script,
Expand Down Expand Up @@ -158,24 +156,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
del_silence=self.configs.audio.del_silence if stage == 'train' else False,
)

def train_dataloader(self) -> DataLoader:
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
def train_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['train'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=train_sampler,
)

def val_dataloader(self) -> DataLoader:
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
def val_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['valid'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=valid_sampler,
)

def test_dataloader(self) -> DataLoader:
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
def test_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['test'],
num_workers=self.configs.trainer.num_workers,
Expand Down
19 changes: 10 additions & 9 deletions openspeech/datasets/ksponspeech/lit_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@
import pytorch_lightning as pl
from typing import Optional
from omegaconf import DictConfig
from openspeech.data.audio.dataset import SpeechToTextDataset

from openspeech.data.audio.dataset import SpeechToTextDataset
from openspeech.datasets import register_data_module
from openspeech.data.sampler import BucketingSampler
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
from openspeech.data.audio.data_loader import AudioDataLoader
from openspeech.datasets.ksponspeech.preprocess.preprocess import preprocess, preprocess_test_data
from openspeech.datasets.ksponspeech.preprocess.character import generate_character_script, generate_character_labels
from openspeech.datasets.ksponspeech.preprocess.grapheme import sentence_to_grapheme
from openspeech.datasets.ksponspeech.preprocess.subword import train_sentencepiece, sentence_to_subwords
from openspeech.tokenizers import TOKENIZER_REGISTRY
from openspeech.tokenizers.tokenizer import Tokenizer


Expand All @@ -49,6 +48,8 @@ class LightningKsponSpeechDataModule(pl.LightningDataModule):

Attributes:
KSPONSPEECH_TRAIN_NUM (int): the number of KsponSpeech's train data.
KSPONSPEECH_VALID_NUM (int): the number of KsponSpeech's validation data.
KSPONSPEECH_TEST_NUM (int): the number of KsponSpeech's test data.

Args:
configs (DictConfig): configuration set.
Expand Down Expand Up @@ -173,26 +174,26 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
)

def train_dataloader(self) -> AudioDataLoader:
r""" Return data loader for training. """
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['train'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=train_sampler,
)

def val_dataloader(self) -> AudioDataLoader:
r""" Return data loader for validation. """
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['valid'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=valid_sampler,
)

def test_dataloader(self) -> AudioDataLoader:
r""" Return data loader for training. """
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['test'],
num_workers=self.configs.trainer.num_workers,
Expand Down
8 changes: 4 additions & 4 deletions openspeech/datasets/language_model/lit_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from omegaconf import DictConfig
from typing import Optional

from openspeech.data.sampler import BucketingSampler
from openspeech.data.sampler import RandomSampler
from openspeech.data.text.data_loader import TextDataLoader
from openspeech.data.text.dataset import TextDataset
from openspeech.datasets import register_data_module
Expand Down Expand Up @@ -78,7 +78,7 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None):
)

def train_dataloader(self) -> TextDataLoader:
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
train_sampler = RandomSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
return TextDataLoader(
dataset=self.dataset['train'],
num_workers=self.configs.trainer.num_workers,
Expand All @@ -87,7 +87,7 @@ def train_dataloader(self) -> TextDataLoader:

def val_dataloader(self) -> TextDataLoader:
r""" Return data loader for validation. """
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
valid_sampler = RandomSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
return TextDataLoader(
dataset=self.dataset['valid'],
num_workers=self.configs.trainer.num_workers,
Expand All @@ -96,7 +96,7 @@ def val_dataloader(self) -> TextDataLoader:

def test_dataloader(self) -> TextDataLoader:
r""" Return data loader for training. """
train_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
train_sampler = RandomSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
return TextDataLoader(
dataset=self.dataset['test'],
num_workers=self.configs.trainer.num_workers,
Expand Down
21 changes: 11 additions & 10 deletions openspeech/datasets/librispeech/lit_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@
import pytorch_lightning as pl
from typing import Tuple, Optional
from omegaconf import DictConfig
from openspeech.data.audio.dataset import SpeechToTextDataset
from torch.utils.data import DataLoader

from openspeech.data.audio.dataset import SpeechToTextDataset
from openspeech.datasets import register_data_module
from openspeech.tokenizers import TOKENIZER_REGISTRY
from openspeech.tokenizers.tokenizer import Tokenizer
from openspeech.data.sampler import BucketingSampler
from openspeech.data.sampler import RandomSampler, SmartBatchingSampler
from openspeech.data.audio.data_loader import AudioDataLoader


Expand Down Expand Up @@ -188,24 +186,27 @@ def setup(self, stage: Optional[str] = None, tokenizer: Tokenizer = None) -> Non
del_silence=self.configs.audio.del_silence if stage == 'train' else False,
)

def train_dataloader(self) -> DataLoader:
train_sampler = BucketingSampler(self.dataset['train'], batch_size=self.configs.trainer.batch_size)
def train_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
train_sampler = sampler(data_source=self.dataset['train'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['train'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=train_sampler,
)

def val_dataloader(self) -> DataLoader:
valid_sampler = BucketingSampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
def val_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
valid_sampler = sampler(self.dataset['valid'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['valid'],
num_workers=self.configs.trainer.num_workers,
batch_sampler=valid_sampler,
)

def test_dataloader(self) -> DataLoader:
test_sampler = BucketingSampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
def test_dataloader(self) -> AudioDataLoader:
sampler = SmartBatchingSampler if self.configs.trainer.sampler == 'smart' else RandomSampler
test_sampler = sampler(self.dataset['test'], batch_size=self.configs.trainer.batch_size)
return AudioDataLoader(
dataset=self.dataset['test'],
num_workers=self.configs.trainer.num_workers,
Expand Down