[ASR] Multichannel mask estimator with flex number of channels #7317
Merged: titu1994 merged 2 commits into NVIDIA-NeMo:main from anteju:pr/flex-channel-mask-estimator on Oct 13, 2023.
examples/audio_tasks/conf/beamforming_flex_channels.yaml (new file: 146 additions, 0 deletions)
```yaml
# This configuration contains example values for training a multichannel
# speech enhancement model with a mask-based beamformer.

name: beamforming_flex_channels

model:
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1

  train_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    input_channel_selector: null # load all channels from the input file
    target_key: target_anechoic_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # load only the first channel from the target file
    audio_duration: 4.0 # in seconds, audio segment duration for training
    random_offset: true # if the file is longer than audio_duration, use a random offset to select a subsegment
    min_duration: ${model.train_ds.audio_duration}
    batch_size: 16 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 16
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: audio_filepath # key of the input signal path in the manifest
    input_channel_selector: null # load all channels from the input file
    target_key: target_anechoic_filepath # key of the target signal path in the manifest
    target_channel_selector: 0 # load only the first channel from the target file
    batch_size: 8
    shuffle: false
    num_workers: 8
    pin_memory: true

  channel_augment:
    _target_: nemo.collections.asr.parts.submodules.multichannel_modules.ChannelAugment
    num_channels_min: 2 # minimum number of channels selected for each batch
    num_channels_max: null # max number of channels is determined by the batch size
    permute_channels: true

  encoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
    fft_length: 512 # length of the window and FFT for calculating the spectrogram
    hop_length: 256 # hop length for calculating the spectrogram

  decoder:
    _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}

  mask_estimator:
    _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels
    num_outputs: ${model.num_outputs} # number of output masks
    num_subbands: 257 # number of subbands of the input spectrogram
    num_blocks: 5 # number of blocks in the model
    channel_reduction_position: 3 # 0-indexed, apply channel reduction before this block
    channel_reduction_type: average # channel-wise reduction
    channel_block_type: transform_average_concatenate # channel block
    temporal_block_type: conformer_encoder # temporal block
    temporal_block_num_layers: 5 # number of layers for the temporal block
    temporal_block_num_heads: 4 # number of heads for the temporal block
    temporal_block_dimension: 128 # hidden size of the temporal block
    mag_reduction: null # channel-wise reduction of magnitude
    mag_normalization: mean_var # normalization using mean and variance
    use_ipd: true # use inter-channel phase difference
    ipd_normalization: mean # mean normalization

  mask_processor:
    # Mask-based multichannel processor
    _target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer
    filter_type: pmwf # parametric multichannel Wiener filter
    filter_beta: 0.0 # MVDR
    filter_rank: one
    ref_channel: max_snr # select the reference channel by maximizing the estimated SNR
    ref_hard: 1 # use a one-hot reference; if false, a soft estimate across channels is used
    ref_hard_use_grad: false # use a straight-through gradient when using a hard reference
    ref_subband_weighting: false # use subband weighting for reference estimation
    num_subbands: ${model.mask_estimator.num_subbands}

  loss:
    _target_: nemo.collections.asr.losses.SDRLoss
    convolution_invariant: true # convolution-invariant loss
    sdr_max: 30 # soft threshold for SDR

  metrics:
    val:
      sdr_0:
        _target_: torchmetrics.audio.SignalDistortionRatio
        channel: 0 # evaluate only on channel 0, if there are multiple outputs

  optim:
    name: adamw
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 10000
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 25 # logging interval
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps run as a sanity check before training starts; 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # provided by exp_manager
  logger: false # provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_loss"
    mode: "min"
    save_top_k: 5
    always_save_nemo: true # save checkpoints as .nemo files instead of PTL checkpoints

  resume_from_checkpoint: null # path to a checkpoint file to continue training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to true to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
```
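For orientation, a config like this is consumed by a small Hydra-based training script in the usual NeMo style. The sketch below is illustrative only: it assumes the standard `hydra_runner`/`exp_manager` wiring and that `EncMaskDecAudioToAudioModel` is exported from `nemo.collections.asr.models`; it is not a script from this PR.

```python
# Minimal launch sketch under the assumptions stated above.
import pytorch_lightning as pl

from nemo.collections.asr.models import EncMaskDecAudioToAudioModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path='./conf', config_name='beamforming_flex_channels')
def main(cfg):
    # Build the PTL trainer from the trainer section and attach exp_manager
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get('exp_manager', None))

    # Instantiate the model from the model section and train
    model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
```

The `???` values are mandatory, so the manifest paths would be supplied at launch time, e.g. `model.train_ds.manifest_filepath=/data/train_manifest.json`.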
Changes to the `EncMaskDecAudioToAudioModel` implementation:

```diff
@@ -61,14 +61,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

         if 'mixture_consistency' in self._cfg:
             logging.debug('Using mixture consistency')
             self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
         else:
             logging.debug('Mixture consistency not used')
             self.mixture_consistency = None

         # Future enhancement:
         # If subclasses need to modify the config before calling super()
         # Check ASRBPE* classes do with their mixin

+        # Setup augmentation
+        if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None:
+            logging.debug('Using channel augmentation')
+            self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment)
+        else:
+            logging.debug('Channel augmentation not used')
+            self.channel_augmentation = None
+
         # Setup optional Optimization flags
         self.setup_optimization_flags()
```
```diff
@@ -125,7 +135,7 @@ def process(
             temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json')
             with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp:
                 for audio_file in paths2audio_files:
-                    entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(filename=audio_file)}
+                    entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)}
                     fp.write(json.dumps(entry) + '\n')

             config = {
```
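This one-line change tracks librosa's API: the `filename` keyword of `librosa.get_duration` was deprecated in librosa 0.10 in favor of `path`.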
```diff
@@ -397,17 +407,23 @@ def training_step(self, batch, batch_idx):
         if target_signal.ndim == 2:
             target_signal = target_signal.unsqueeze(1)

+        # Apply channel augmentation
+        if self.training and self.channel_augmentation is not None:
+            input_signal = self.channel_augmentation(input=input_signal)
+
         # Process input
         processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

-        loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
+        # Calculate the loss
+        loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

-        tensorboard_logs = {
-            'train_loss': loss_value,
-            'learning_rate': self._optimizer.param_groups[0]['lr'],
-            'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32),
-        }
+        # Logs
+        self.log('train_loss', loss)
+        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
+        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

-        return {'loss': loss_value, 'log': tensorboard_logs}
+        # Return loss
+        return loss

     def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
         input_signal, input_length, target_signal, target_length = batch
```
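Besides inserting the channel augmentation, this hunk modernizes the logging: returning a dict with a `'log'` key is a leftover from old PyTorch Lightning versions and is no longer supported in the 1.x/2.x releases, so metrics are now logged through `self.log(...)` and the step simply returns the loss tensor.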
```diff
@@ -419,11 +435,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
         if target_signal.ndim == 2:
             target_signal = target_signal.unsqueeze(1)

         # Process input
         processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

-        # Prepare output
-        loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
-        output_dict = {f'{tag}_loss': loss_value}
+        # Calculate the loss
+        loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

         # Update metrics
         if hasattr(self, 'metrics') and tag in self.metrics:
```
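In the metric update just below, `metric.update(...)` also receives `input_length`, which a raw torchmetrics object does not accept; NeMo wraps the configured metrics so the wrapper handles `input_length` and channel selection. Stripped of that wrapper, the accumulate/compute cycle of the `SignalDistortionRatio` configured under `metrics.val` looks roughly like this (values illustrative):

```python
import torch
from torchmetrics.audio import SignalDistortionRatio

# Illustrative accumulate/compute cycle of the configured validation metric.
sdr = SignalDistortionRatio()
for _ in range(3):  # stands in for three validation steps
    preds = torch.randn(2, 1, 16000)   # (batch, channel, time)
    target = torch.randn(2, 1, 16000)
    sdr.update(preds=preds, target=target)  # accumulate per-step statistics
print(sdr.compute())  # aggregate over all updates at epoch end
sdr.reset()           # clear state for the next epoch
```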
```diff
@@ -432,19 +448,10 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
                 metric.update(preds=processed_signal, target=target_signal, input_length=input_length)

         # Log global step
-        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)
+        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

-        if tag == 'val':
-            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
-                self.validation_step_outputs[dataloader_idx].append(output_dict)
-            else:
-                self.validation_step_outputs.append(output_dict)
-        else:
-            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
-                self.test_step_outputs[dataloader_idx].append(output_dict)
-            else:
-                self.test_step_outputs.append(output_dict)
-
-        return output_dict
+        # Return loss
+        return {f'{tag}_loss': loss}

     @classmethod
     def list_available_models(cls) -> Optional[PretrainedModelInfo]:
```

anteju (Collaborator, Author) commented on this hunk: Moved handling of multiple dataloaders to …
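The destination is cut off in the comment above, so the following is context rather than a quote: NeMo's `ModelPT` base class provides `multi_validation_epoch_end`/`multi_test_epoch_end` hooks that aggregate per-dataloader outputs, which is the natural home for the removed bookkeeping; `evaluation_step` then only needs to return `{f'{tag}_loss': loss}` instead of appending to `validation_step_outputs`/`test_step_outputs` itself.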