From 1577275d02b750899d8a499a074f6e735afab958 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 13:05:43 +0200
Subject: [PATCH 1/7] add vocoders

---
 src/diffusers/models/vocoders.py | 131 +++++++++++++++++++++++++++++++
 1 file changed, 131 insertions(+)
 create mode 100644 src/diffusers/models/vocoders.py

diff --git a/src/diffusers/models/vocoders.py b/src/diffusers/models/vocoders.py
new file mode 100644
index 000000000000..ea2a616463ad
--- /dev/null
+++ b/src/diffusers/models/vocoders.py
@@ -0,0 +1,131 @@
+# All the vocoders used in diffusion pipelines will be implemented here.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..modeling_utils import ModelMixin
+from ..configuration_utils import ConfigMixin
+
+
+# DiffSound uses MelGAN
+class MelGAN(nn.Module):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        return
+
+
+class CausalConv1d(nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)
+
+    def forward(self, x):
+        # Left-pad so the convolution never looks at future samples.
+        return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)
+
+
+class CausalConvTranspose1d(nn.ConvTranspose1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.causal_padding = (
+            self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0]
+        )
+
+    def forward(self, x, output_size=None):
+        if self.padding_mode != "zeros":
+            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d")
+
+        assert isinstance(self.padding, tuple)
+        output_padding = self._output_padding(
+            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation
+        )
+        # Trim the trailing samples so the transposed convolution stays causal.
+        return F.conv_transpose1d(
+            x, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation
+        )[..., : -self.causal_padding]
+
+
+class SoundStreamResNet(nn.Module):
+    def __init__(self, in_channels, out_channels, dilation):
+        super().__init__()
+        self.dilation = dilation
+        # `CausalConv1d` is the class defined above; `torch.nn` has no causal convolution.
+        self.causal_conv = CausalConv1d(in_channels, out_channels, kernel_size=7, dilation=dilation)
+        self.conv_1d = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+        self.act = nn.ELU()
+
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        hidden_states = self.causal_conv(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.conv_1d(hidden_states)
+        return residuals + hidden_states
+
+
+class SoundStreamDecoderBlock(nn.Module):
+    def __init__(self, out_channels, stride):
+        super().__init__()
+        self.project_in = CausalConvTranspose1d(
+            in_channels=2 * out_channels, out_channels=out_channels, kernel_size=2 * stride, stride=stride
+        )
+        self.act = nn.ELU()
+
+        # SoundStream residual units use dilation rates 1, 3 and 9, i.e. 3 ** rate (not the XOR 3 ^ rate).
+        self.resnet_blocks = nn.ModuleList(
+            [SoundStreamResNet(out_channels, out_channels, dilation=3**rate) for rate in range(3)]
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.project_in(hidden_states)
+        hidden_states = self.act(hidden_states)
+        for resnet in self.resnet_blocks:
+            hidden_states = resnet(hidden_states)
+        return hidden_states
+
+
+# notes2audio uses SoundStream
+class SoundStreamVocoder(ModelMixin, ConfigMixin):
+    """Decoder of the residual VQ-VAE from `SoundStream: An End-to-End Neural Audio Codec`.
+
+    Args:
+        in_channels (`int`): number of input channels. It corresponds to the number of spectrogram features
+            that are passed to the decoder to compute the raw audio.
+        out_channels (`int`): channel width of the last decoder block; the final convolution always maps to a
+            single (mono) audio channel.
+        strides (`List[int]`): upsampling factor of each decoder block.
+        channel_factors (`List[int]`): channel multipliers of the decoder blocks.
+    """
+
+    def __init__(self, in_channels=8, out_channels=1, strides=[8, 5, 4, 2], channel_factors=[8, 4, 2, 1]):
+        super().__init__()
+        self.act = nn.ELU()
+        self.bottleneck = CausalConv1d(in_channels=in_channels, out_channels=16 * out_channels, kernel_size=7)
+        self.decoder_blocks = nn.ModuleList(
+            SoundStreamDecoderBlock(out_channels=out_channels * channel_factors[i], stride=strides[i])
+            for i in range(4)
+        )
+        self.last_layer_conv = CausalConv1d(in_channels=out_channels, out_channels=1, kernel_size=7)
+        self._decode_dither_amount = 0.0  # dithering is disabled by default
+
+    def decode(self, features):
+        """Decodes features to audio.
+
+        Args:
+            features: Mel spectrograms. The 1D convolutions below expect channels-first input,
+                i.e. shape [batch, n_dims, n_frames].
+        Returns:
+            audio: Shape [batch, 1, n_frames * hop_size]
+        """
+        if self._decode_dither_amount > 0:
+            features = features + torch.randn_like(features) * self._decode_dither_amount
+
+        hidden_states = self.bottleneck(features)
+        hidden_states = self.act(hidden_states)
+        for layer in self.decoder_blocks:
+            hidden_states = layer(hidden_states)
+            hidden_states = self.act(hidden_states)
+
+        audio = self.last_layer_conv(hidden_states)
+
+        return audio
+
+
+# TODO @Arthur DiffSinger uses this as vocoder
+class HiFiGAN(nn.Module):
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        return
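The decoder above chains a bottleneck causal convolution, four upsampling blocks (total upsampling factor 8 * 5 * 4 * 2 = 320, the mel hop size used by the conversion script later in this series) and a final causal convolution. A minimal smoke test, assuming channels-first mel features and the draft's placeholder defaults; the import path and shapes are assumptions, not a finished API:

```python
# Rough smoke test for the SoundStream decoder sketched in the patch above.
# Assumes `diffusers.models.vocoders` is importable and that `decode` expects
# channels-first features of shape (batch, in_channels, n_frames).
import torch

from diffusers.models.vocoders import SoundStreamVocoder

vocoder = SoundStreamVocoder(in_channels=8, out_channels=1)
features = torch.randn(1, 8, 100)  # (batch, in_channels, n_frames)
with torch.no_grad():
    audio = vocoder.decode(features)
print(audio.shape)  # roughly (1, 1, 100 * 320) before any causal trimming
```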
From f4373973e69939b63ff096829c4cacb9f8e2e9a2 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 13:06:09 +0200
Subject: [PATCH 2/7] draft pipeline

---
 .../notes2audio/pipeline_notes2audio.py | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py

diff --git a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
new file mode 100644
index 000000000000..bca897e1c570
--- /dev/null
+++ b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
@@ -0,0 +1,58 @@
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import T5Model
+
+from ...models import Notes2AudioModel, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+
+
+class Notes2AudioPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for notes (MIDI)-to-audio generation using the music-spectrogram diffusion model introduced by
+    Magenta in notes2audio.
+
+    Args:
+        spectrogram_decoder ([` `]):
+            Decoder model used to convert the encoder hidden states into a mel spectrogram. It takes the encoder
+            hidden states as well as the diffusion noise as inputs. This should be the SoundStream/MelGAN-style
+            decoder.
+        context_encoder ([` `]):
+            Encoder used to create the context that smooths the transitions between adjacent audio frames.
+        note_encoder (` `):
+            Model used to encode the note (MIDI) tokens.
+        vocoder ([` `]):
+            Vocoder (e.g. the SoundStream decoder) used to turn the generated mel spectrogram into raw audio.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `spectrogram_decoder` to denoise the encoded spectrogram
+            latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+    """
+
+    def __init__(self, spectrogram_decoder, context_encoder, note_encoder, vocoder, scheduler):
+        self.spectrogram_decoder = spectrogram_decoder
+        self.context_encoder = context_encoder
+        self.note_encoder = note_encoder
+        self.vocoder = vocoder
+        self.scheduler = scheduler
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(
+            spectrogram_decoder=spectrogram_decoder,
+            context_encoder=context_encoder,
+            note_encoder=note_encoder,
+            scheduler=scheduler,
+            vocoder=vocoder
+        )
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        midi: Union[str, List[str]],
+        audio_length: int,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        return
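`__call__` is still a stub. For orientation, a rough sketch of the per-segment denoising loop it will eventually need, following the usual diffusers scheduler pattern; the module call signatures (`note_encoder`, `context_encoder`, `spectrogram_decoder`) and tensor shapes below are assumptions, only `set_timesteps`/`step` follow the existing scheduler API:

```python
# Sketch of the denoising loop the pipeline will need (not part of the patch).
# `note_tokens` and `context` are hypothetical inputs: the tokenized MIDI segment and the
# previously generated audio segment used by the context encoder.
import torch


def denoise_segment(pipeline, note_tokens, context, num_inference_steps=50, generator=None):
    # Encode the MIDI tokens and the previous segment once, reuse them at every step.
    encoder_hidden_states = pipeline.note_encoder(note_tokens)
    context_hidden_states = pipeline.context_encoder(context)

    spectrogram = torch.randn(context.shape, generator=generator)  # start from pure noise
    pipeline.scheduler.set_timesteps(num_inference_steps)
    for t in pipeline.scheduler.timesteps:
        # The spectrogram decoder predicts the noise residual, conditioned on notes + context.
        noise_pred = pipeline.spectrogram_decoder(spectrogram, t, encoder_hidden_states, context_hidden_states)
        spectrogram = pipeline.scheduler.step(noise_pred, t, spectrogram).prev_sample
    return spectrogram
```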
From 0eab64bae5ad1b88f2ec82eec3693427e9bf7fcd Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 13:06:27 +0200
Subject: [PATCH 3/7] README

---
 src/diffusers/pipelines/notes2audio/README.md   | 17 +++++++++++++++++
 src/diffusers/pipelines/notes2audio/__init__.py |  0
 2 files changed, 17 insertions(+)
 create mode 100644 src/diffusers/pipelines/notes2audio/README.md
 create mode 100644 src/diffusers/pipelines/notes2audio/__init__.py

diff --git a/src/diffusers/pipelines/notes2audio/README.md b/src/diffusers/pipelines/notes2audio/README.md
new file mode 100644
index 000000000000..294c867beaf1
--- /dev/null
+++ b/src/diffusers/pipelines/notes2audio/README.md
@@ -0,0 +1,17 @@
+# TODO Follow the stable diffusion pipeline card
+
+
+Goal of the implementation:
+
+```python
+
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("magenta/notes2audio_base_with_context")
+
+midi_setup_file = "path/to/midi_file.midi"
+pipeline(midi_setup_file).sample[0]
+
+
+
+```
diff --git a/src/diffusers/pipelines/notes2audio/__init__.py b/src/diffusers/pipelines/notes2audio/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
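The next commit starts the conversion work; step 1 there (SoundStream TF weights to PyTorch) will need a variable-name mapping. A sketch of how the TF Hub module's variables could be re-keyed into a state dict for `SoundStreamVocoder`; the `NAME_MAP` entries are placeholders, and exposing the SavedModel variables via `module.weights` is an assumption to verify:

```python
# Sketch only: dump the TF Hub SoundStream decoder variables into a PyTorch state dict.
# The `NAME_MAP` entries are hypothetical -- the real TF variable names have to be read
# from `module.weights` and matched to the `SoundStreamVocoder` parameters by hand.
import tensorflow_hub as hub
import torch

module = hub.KerasLayer("https://tfhub.dev/google/soundstream/mel/decoder/music/1")

NAME_MAP = {
    # "decoder/conv1d/kernel:0": "bottleneck.weight",  # example entry, to be verified
}

state_dict = {}
for variable in module.weights:
    if variable.name not in NAME_MAP:
        continue
    tensor = torch.from_numpy(variable.numpy())
    if tensor.ndim == 3:
        # TF Conv1D kernels are (width, in, out); torch Conv1d wants (out, in, width).
        tensor = tensor.permute(2, 1, 0).contiguous()
    state_dict[NAME_MAP[variable.name]] = tensor

# vocoder.load_state_dict(state_dict, strict=False) once every layer is mapped.
```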
From b04cdbc02f53da27ad78b618ff3d6e8e534de4e8 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 13:06:52 +0200
Subject: [PATCH 4/7] draft conversion script and music transformer

---
 scripts/convert_notes2audio.py            |  69 ++++++++++
 src/diffusers/models/music_transformer.py | 147 ++++++++++++++++++++++
 2 files changed, 216 insertions(+)
 create mode 100644 scripts/convert_notes2audio.py
 create mode 100644 src/diffusers/models/music_transformer.py

diff --git a/scripts/convert_notes2audio.py b/scripts/convert_notes2audio.py
new file mode 100644
index 000000000000..1971a349bc8c
--- /dev/null
+++ b/scripts/convert_notes2audio.py
@@ -0,0 +1,69 @@
+import math
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import tensorflow_hub as hub
+
+module = hub.KerasLayer('https://tfhub.dev/google/soundstream/mel/decoder/music/1')
+
+# 1. Convert the TF weights of SoundStream to PyTorch
+# This will give us the necessary vocoder
+
+
+# 2. Convert the JAX T5 weights to PyTorch using the transformers script
+# This will give us the necessary encoder and decoder
+# The encoder corresponds to the note encoder and the decoder part is the spectrogram decoder
+
+# 3. Convert the Context Encoder weights to PyTorch
+# The context encoder should be pretty straightforward to convert
+
+# 4. Implement tests to make sure that the models work properly
+
+
+SAMPLE_RATE = 16000
+N_FFT = 1024
+HOP_LENGTH = 320
+WIN_LENGTH = 640
+N_MEL_CHANNELS = 128
+MEL_FMIN = 0.0
+MEL_FMAX = int(SAMPLE_RATE // 2)
+CLIP_VALUE_MIN = 1e-5
+CLIP_VALUE_MAX = 1e8
+
+MEL_BASIS = tf.signal.linear_to_mel_weight_matrix(
+    num_mel_bins=N_MEL_CHANNELS,
+    num_spectrogram_bins=N_FFT // 2 + 1,
+    sample_rate=SAMPLE_RATE,
+    lower_edge_hertz=MEL_FMIN,
+    upper_edge_hertz=MEL_FMAX)
+
+def calculate_spectrogram(samples):
+  """Calculate mel spectrogram using the parameters the model expects."""
+  fft = tf.signal.stft(
+      samples,
+      frame_length=WIN_LENGTH,
+      frame_step=HOP_LENGTH,
+      fft_length=N_FFT,
+      window_fn=tf.signal.hann_window,
+      pad_end=True)
+  fft_modulus = tf.abs(fft)
+
+  output = tf.matmul(fft_modulus, MEL_BASIS)
+
+  output = tf.clip_by_value(
+      output,
+      clip_value_min=CLIP_VALUE_MIN,
+      clip_value_max=CLIP_VALUE_MAX)
+  output = tf.math.log(output)
+  return output
+
+# Load a music sample from the GTZAN dataset.
+gtzan = tfds.load('gtzan', split='train')
+# Convert an example from int to float.
+samples = tf.cast(next(iter(gtzan))['audio'] / 32768, dtype=tf.float32)
+# Add batch dimension.
+samples = tf.expand_dims(samples, axis=0)
+# Compute a mel-spectrogram.
+spectrogram = calculate_spectrogram(samples)
+# Reconstruct the audio from a mel-spectrogram using a SoundStream decoder.
+reconstructed_samples = module(spectrogram)
\ No newline at end of file
diff --git a/src/diffusers/models/music_transformer.py b/src/diffusers/models/music_transformer.py
new file mode 100644
index 000000000000..8aeb3451fb9a
--- /dev/null
+++ b/src/diffusers/models/music_transformer.py
@@ -0,0 +1,147 @@
+# This file will contain the necessary classes to build the notes2audio pipeline:
+# Note Encoder, Spectrogram Decoder and Context Encoder
+
+
+
+import torch
+import torch.nn as nn
+
+from transformers.models.t5.modeling_t5 import T5Stack, T5Block
+from transformers import T5Config
+
+class ContextEncoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+class NoteEncoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+class SpectrogramDecoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# NOTE: the two classes below are still the original Flax/JAX implementation copied from Magenta
+# (jnp, layers.Embed, `deterministic` flags, ...) and still need to be ported to PyTorch.
+class TokenEncoder(nn.Module):
+  """A stack of encoder layers."""
+  config: T5Config
+
+  def __call__(self,
+               encoder_input_tokens,
+               encoder_inputs_mask,
+               deterministic):
+    cfg = self.config
+
+    assert encoder_input_tokens.ndim == 2  # [batch, length]
+
+    seq_length = encoder_input_tokens.shape[1]
+    inputs_positions = jnp.arange(seq_length)[None, :]
+
+    # [batch, length] -> [batch, length, emb_dim]
+    x = layers.Embed(
+        num_embeddings=cfg.vocab_size,
+        features=cfg.emb_dim,
+        dtype=cfg.dtype,
+        embedding_init=nn.initializers.normal(stddev=1.0),
+        one_hot=True,
+        name='token_embedder')(encoder_input_tokens.astype('int32'))
+
+    x += position_encoding_layer(config=cfg, max_length=seq_length)(
+        inputs_positions)
+    x = nn.Dropout(
+        rate=cfg.dropout_rate, broadcast_dims=(-2,))(
+            x, deterministic=deterministic)
+    x = x.astype(cfg.dtype)
+
+    for lyr in range(cfg.num_encoder_layers):
+      # [batch, length, emb_dim] -> [batch, length, emb_dim]
+      x = EncoderLayer(
+          config=cfg,
+          name=f'layers_{lyr}')(
+              inputs=x,
+              encoder_inputs_mask=encoder_inputs_mask,
+              deterministic=deterministic)
+    x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
+    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
+    return x, encoder_inputs_mask
+
+
+class ContinuousContextTransformer(nn.Module):
+  """An encoder-decoder Transformer model with a second audio context encoder."""
+  config: T5Config
+
+  def setup(self):
+    cfg = self.config
+
+    self.token_encoder = TokenEncoder(config=cfg)
+    self.continuous_encoder = ContinuousEncoder(config=cfg)
+    self.decoder = Decoder(config=cfg)
+
+  def encode(self,
+             input_tokens,
+             continuous_inputs,
+             continuous_mask,
+             enable_dropout=True):
+    """Applies Transformer encoder-branch on the inputs."""
+    assert input_tokens.ndim == 2  # (batch, length)
+    assert continuous_inputs.ndim == 3  # (batch, length, input_dims)
+
+    tokens_mask = input_tokens > 0
+
+    tokens_encoded, tokens_mask = self.token_encoder(
+        encoder_input_tokens=input_tokens,
+        encoder_inputs_mask=tokens_mask,
+        deterministic=not enable_dropout)
+
+    continuous_encoded, continuous_mask = self.continuous_encoder(
+        encoder_inputs=continuous_inputs,
+        encoder_inputs_mask=continuous_mask,
+        deterministic=not enable_dropout)
+
+    return [(tokens_encoded, tokens_mask),
+            (continuous_encoded, continuous_mask)]
+
+  def decode(
+      self,
+      encodings_and_masks,
+      input_tokens,
+      noise_time,
+      enable_dropout=True):
+    """Applies Transformer decoder-branch on encoded-input and target."""
+    logits = self.decoder(
+        encodings_and_masks=encodings_and_masks,
+        decoder_input_tokens=input_tokens,
+        decoder_noise_time=noise_time,
+        deterministic=not enable_dropout)
+    return logits.astype(self.config.dtype)
+
+  def __call__(self,
+               encoder_input_tokens,
+               encoder_continuous_inputs,
+               encoder_continuous_mask,
+               decoder_input_tokens,
+               decoder_noise_time,
+               *,
+               enable_dropout: bool = True):
+    """Applies Transformer model on the inputs.
+
+    Args:
+      encoder_input_tokens: input data to the encoder.
+      encoder_continuous_inputs: continuous inputs for the second encoder.
+      encoder_continuous_mask: mask for continuous inputs.
+      decoder_input_tokens: input token to the decoder.
+      decoder_noise_time: noise continuous time for diffusion.
+      enable_dropout: Enables dropout if set to True.
+    Returns:
+      logits array from full transformer.
+    """
+    encodings_and_masks = self.encode(
+        input_tokens=encoder_input_tokens,
+        continuous_inputs=encoder_continuous_inputs,
+        continuous_mask=encoder_continuous_mask,
+        enable_dropout=enable_dropout)
+
+    return self.decode(
+        encodings_and_masks=encodings_and_masks,
+        input_tokens=decoder_input_tokens,
+        noise_time=decoder_noise_time,
+        enable_dropout=enable_dropout)
\ No newline at end of file
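As noted in the file, `TokenEncoder` and `ContinuousContextTransformer` are still the Flax modules copied from Magenta. One possible shape for the PyTorch port of the token encoder, reusing the `T5Stack` that music_transformer.py already imports; treating T5's relative-attention encoder as a stand-in for Magenta's fixed positional encodings, and the configuration/mask handling below, are assumptions rather than the final design:

```python
# Sketch of a PyTorch port of the Flax `TokenEncoder` (not part of the patch).
import copy

import torch
import torch.nn as nn
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Stack


class TorchTokenEncoder(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.token_embedder = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = T5Stack(encoder_config, self.token_embedder)

    def forward(self, encoder_input_tokens, encoder_inputs_mask):
        # Mirrors the Flax module: returns the encoded sequence together with its mask.
        hidden_states = self.encoder(
            input_ids=encoder_input_tokens,
            attention_mask=encoder_inputs_mask,
        ).last_hidden_state
        return hidden_states, encoder_inputs_mask
```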
From baf39aabb7f0629ebfe941f569ea123c9d6b0df7 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 17:17:48 +0200
Subject: [PATCH 5/7] add film layer

---
 src/diffusers/models/music_transformer.py | 28 +++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/src/diffusers/models/music_transformer.py b/src/diffusers/models/music_transformer.py
index 8aeb3451fb9a..07be51286391 100644
--- a/src/diffusers/models/music_transformer.py
+++ b/src/diffusers/models/music_transformer.py
@@ -9,6 +9,34 @@
 from transformers.models.t5.modeling_t5 import T5Stack, T5Block
 from transformers import T5Config
 
+class FiLMLayer(nn.Module):
+    """A simple FiLM layer for conditioning on the diffusion time embedding.
+
+    """
+
+    def __init__(self, in_channels, out_channels) -> None:
+        super().__init__()
+        self.gamma = nn.Linear(in_channels, out_channels)  # scale
+        self.beta = nn.Linear(in_channels, out_channels)  # shift
+
+    def forward(self, hidden_states, conditioning_emb):
+        """Updates the hidden states based on the conditioning embeddings.
+ + Args: + hidden_states (`Tensor`): _description_ + conditioning_emb (`Tensor`): _description_ + + Returns: + _type_: _description_ + """ + + beta = self.beta(conditioning_emb).unsqueeze(-1).unsqueeze(-1) + gamma = self.gamma(conditioning_emb).unsqueeze(-1).unsqueeze(-1) + + hidden_states = hidden_states * (gamma + 1.0) + beta + return hidden_states + + class ContextEncoder(nn.Module): def __init__(self) -> None: super().__init__() From c5b1620b3525a2c64a82124fbdec1993ec549745 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 17 Sep 2022 17:18:09 +0200 Subject: [PATCH 6/7] style --- scripts/convert_notes2audio.py | 56 ++-- src/diffusers/models/music_transformer.py | 263 +++++++++--------- src/diffusers/models/vocoders.py | 3 +- .../notes2audio/pipeline_notes2audio.py | 2 +- 4 files changed, 157 insertions(+), 167 deletions(-) diff --git a/scripts/convert_notes2audio.py b/scripts/convert_notes2audio.py index 1971a349bc8c..b989247b17e4 100644 --- a/scripts/convert_notes2audio.py +++ b/scripts/convert_notes2audio.py @@ -1,23 +1,24 @@ import math + import tensorflow as tf import tensorflow_datasets as tfds import tensorflow_hub as hub -module = hub.KerasLayer('https://tfhub.dev/google/soundstream/mel/decoder/music/1') + +module = hub.KerasLayer("https://tfhub.dev/google/soundstream/mel/decoder/music/1") # 1. Convert the TF weights of SOUNDSTREAM to PyTorch # This will give us the necessary vocoder - -# 2. Convert JAX T5 weights to Pytorch using the transformers script +# 2. Convert JAX T5 weights to Pytorch using the transformers script # This will give us the necessary encoder and decoder # Then encoder corresponds to the note encoder and the decoder part is the spectrogram decoder -# 3. Convert eh Context Encoder weights to Pytorch +# 3. Convert eh Context Encoder weights to Pytorch # The context encoder should be pretty straightforward to convert -# 4. Implement tests to make sure that the models work properly +# 4. Implement tests to make sure that the models work properly SAMPLE_RATE = 16000 @@ -35,35 +36,36 @@ num_spectrogram_bins=N_FFT // 2 + 1, sample_rate=SAMPLE_RATE, lower_edge_hertz=MEL_FMIN, - upper_edge_hertz=MEL_FMAX) + upper_edge_hertz=MEL_FMAX, +) + def calculate_spectrogram(samples): - """Calculate mel spectrogram using the parameters the model expects.""" - fft = tf.signal.stft( - samples, - frame_length=WIN_LENGTH, - frame_step=HOP_LENGTH, - fft_length=N_FFT, - window_fn=tf.signal.hann_window, - pad_end=True) - fft_modulus = tf.abs(fft) - - output = tf.matmul(fft_modulus, MEL_BASIS) - - output = tf.clip_by_value( - output, - clip_value_min=CLIP_VALUE_MIN, - clip_value_max=CLIP_VALUE_MAX) - output = tf.math.log(output) - return output + """Calculate mel spectrogram using the parameters the model expects.""" + fft = tf.signal.stft( + samples, + frame_length=WIN_LENGTH, + frame_step=HOP_LENGTH, + fft_length=N_FFT, + window_fn=tf.signal.hann_window, + pad_end=True, + ) + fft_modulus = tf.abs(fft) + + output = tf.matmul(fft_modulus, MEL_BASIS) + + output = tf.clip_by_value(output, clip_value_min=CLIP_VALUE_MIN, clip_value_max=CLIP_VALUE_MAX) + output = tf.math.log(output) + return output + # Load a music sample from the GTZAN dataset. -gtzan = tfds.load('gtzan', split='train') +gtzan = tfds.load("gtzan", split="train") # Convert an example from int to float. -samples = tf.cast(next(iter(gtzan))['audio'] / 32768, dtype=tf.float32) +samples = tf.cast(next(iter(gtzan))["audio"] / 32768, dtype=tf.float32) # Add batch dimension. 
samples = tf.expand_dims(samples, axis=0) # Compute a mel-spectrogram. spectrogram = calculate_spectrogram(samples) # Reconstruct the audio from a mel-spectrogram using a SoundStream decoder. -reconstructed_samples = module(spectrogram) \ No newline at end of file +reconstructed_samples = module(spectrogram) diff --git a/src/diffusers/models/music_transformer.py b/src/diffusers/models/music_transformer.py index 07be51286391..ebb258a1787e 100644 --- a/src/diffusers/models/music_transformer.py +++ b/src/diffusers/models/music_transformer.py @@ -2,25 +2,23 @@ # Note Encoder, Spectrogram Decoder and Context Encoder - -import torch +import torch import torch.nn as nn -from transformers.models.t5.modeling_t5 import T5Stack, T5Block from transformers import T5Config +from transformers.models.t5.modeling_t5 import T5Block, T5Stack + class FiLMLayer(nn.Module): - """A simple FiLM layer for conditioning on the diffusion time embedding. - - """ + """A simple FiLM layer for conditioning on the diffusion time embedding.""" def __init__(self, in_channels, out_channels) -> None: super().__init__() - self.gamma = nn.Linear(in_channels, out_channels) # s - self.beta = nn.Linear(in_channels, out_channels) # t - + self.gamma = nn.Linear(in_channels, out_channels) # s + self.beta = nn.Linear(in_channels, out_channels) # t + def forward(self, hidden_states, conditioning_emb): - """Updates the hidden states based on the conditioning embeddings. + """Updates the hidden states based on the conditioning embeddings. Args: hidden_states (`Tensor`): _description_ @@ -29,147 +27,136 @@ def forward(self, hidden_states, conditioning_emb): Returns: _type_: _description_ """ - + beta = self.beta(conditioning_emb).unsqueeze(-1).unsqueeze(-1) gamma = self.gamma(conditioning_emb).unsqueeze(-1).unsqueeze(-1) - + hidden_states = hidden_states * (gamma + 1.0) + beta return hidden_states class ContextEncoder(nn.Module): def __init__(self) -> None: - super().__init__() - + super().__init__() + + class NoteEncoder(nn.Module): def __init__(self) -> None: - super().__init__() - + super().__init__() + + class SpectrogramDecoder(nn.Module): def __init__(self) -> None: - super().__init__() - - + super().__init__() + + class TokenEncoder(nn.Module): - """A stack of encoder layers.""" - config: T5Config - - def __call__(self, - encoder_input_tokens, - encoder_inputs_mask, - deterministic): - cfg = self.config - - assert encoder_input_tokens.ndim == 2 # [batch, length] - - seq_length = encoder_input_tokens.shape[1] - inputs_positions = jnp.arange(seq_length)[None, :] - - # [batch, length] -> [batch, length, emb_dim] - x = layers.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - one_hot=True, - name='token_embedder')(encoder_input_tokens.astype('int32')) - - x += position_encoding_layer(config=cfg, max_length=seq_length)( - inputs_positions) - x = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - x, deterministic=deterministic) - x = x.astype(cfg.dtype) - - for lyr in range(cfg.num_encoder_layers): - # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = EncoderLayer( - config=cfg, - name=f'layers_{lyr}')( - inputs=x, - encoder_inputs_mask=encoder_inputs_mask, - deterministic=deterministic) - x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) - x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) - return x, encoder_inputs_mask + """A stack of encoder layers.""" + + config: T5Config + + def __call__(self, 
encoder_input_tokens, encoder_inputs_mask, deterministic): + cfg = self.config + + assert encoder_input_tokens.ndim == 2 # [batch, length] + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = jnp.arange(seq_length)[None, :] + + # [batch, length] -> [batch, length, emb_dim] + x = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=True, + name="token_embedder", + )(encoder_input_tokens.astype("int32")) + + x += position_encoding_layer(config=cfg, max_length=seq_length)(inputs_positions) + x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(x, deterministic=deterministic) + x = x.astype(cfg.dtype) + + for lyr in range(cfg.num_encoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] + x = EncoderLayer(config=cfg, name=f"layers_{lyr}")( + inputs=x, encoder_inputs_mask=encoder_inputs_mask, deterministic=deterministic + ) + x = layers.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x) + x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + return x, encoder_inputs_mask class ContinuousContextTransformer(nn.Module): - """An encoder-decoder Transformer model with a second audio context encoder.""" - config: T5Config - - def setup(self): - cfg = self.config - - self.token_encoder = TokenEncoder(config=cfg) - self.continuous_encoder = ContinuousEncoder(config=cfg) - self.decoder = Decoder(config=cfg) - - def encode(self, - input_tokens, - continuous_inputs, - continuous_mask, - enable_dropout=True): - """Applies Transformer encoder-branch on the inputs.""" - assert input_tokens.ndim == 2 # (batch, length) - assert continuous_inputs.ndim == 3 # (batch, length, input_dims) - - tokens_mask = input_tokens > 0 - - tokens_encoded, tokens_mask = self.token_encoder( - encoder_input_tokens=input_tokens, - encoder_inputs_mask=tokens_mask, - deterministic=not enable_dropout) - - continuous_encoded, continuous_mask = self.continuous_encoder( - encoder_inputs=continuous_inputs, - encoder_inputs_mask=continuous_mask, - deterministic=not enable_dropout) - - return [(tokens_encoded, tokens_mask), - (continuous_encoded, continuous_mask)] - - def decode( - self, - encodings_and_masks, - input_tokens, - noise_time, - enable_dropout=True): - """Applies Transformer decoder-branch on encoded-input and target.""" - logits = self.decoder( - encodings_and_masks=encodings_and_masks, - decoder_input_tokens=input_tokens, - decoder_noise_time=noise_time, - deterministic=not enable_dropout) - return logits.astype(self.config.dtype) - - def __call__(self, - encoder_input_tokens, - encoder_continuous_inputs, - encoder_continuous_mask, - decoder_input_tokens, - decoder_noise_time, - *, - enable_dropout: bool = True): - """Applies Transformer model on the inputs. - Args: - encoder_input_tokens: input data to the encoder. - encoder_continuous_inputs: continuous inputs for the second encoder. - encoder_continuous_mask: mask for continuous inputs. - decoder_input_tokens: input token to the decoder. - decoder_noise_time: noise continuous time for diffusion. - enable_dropout: Ensables dropout if set to True. - Returns: - logits array from full transformer. 
- """ - encodings_and_masks = self.encode( - input_tokens=encoder_input_tokens, - continuous_inputs=encoder_continuous_inputs, - continuous_mask=encoder_continuous_mask, - enable_dropout=enable_dropout) - - return self.decode( - encodings_and_masks=encodings_and_masks, - input_tokens=decoder_input_tokens, - noise_time=decoder_noise_time, - enable_dropout=enable_dropout) \ No newline at end of file + """An encoder-decoder Transformer model with a second audio context encoder.""" + + config: T5Config + + def setup(self): + cfg = self.config + + self.token_encoder = TokenEncoder(config=cfg) + self.continuous_encoder = ContinuousEncoder(config=cfg) + self.decoder = Decoder(config=cfg) + + def encode(self, input_tokens, continuous_inputs, continuous_mask, enable_dropout=True): + """Applies Transformer encoder-branch on the inputs.""" + assert input_tokens.ndim == 2 # (batch, length) + assert continuous_inputs.ndim == 3 # (batch, length, input_dims) + + tokens_mask = input_tokens > 0 + + tokens_encoded, tokens_mask = self.token_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask, deterministic=not enable_dropout + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask, deterministic=not enable_dropout + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time, enable_dropout=True): + """Applies Transformer decoder-branch on encoded-input and target.""" + logits = self.decoder( + encodings_and_masks=encodings_and_masks, + decoder_input_tokens=input_tokens, + decoder_noise_time=noise_time, + deterministic=not enable_dropout, + ) + return logits.astype(self.config.dtype) + + def __call__( + self, + encoder_input_tokens, + encoder_continuous_inputs, + encoder_continuous_mask, + decoder_input_tokens, + decoder_noise_time, + *, + enable_dropout: bool = True, + ): + """Applies Transformer model on the inputs. + Args: + encoder_input_tokens: input data to the encoder. + encoder_continuous_inputs: continuous inputs for the second encoder. + encoder_continuous_mask: mask for continuous inputs. + decoder_input_tokens: input token to the decoder. + decoder_noise_time: noise continuous time for diffusion. + enable_dropout: Ensables dropout if set to True. + Returns: + logits array from full transformer. 
+        """
+        encodings_and_masks = self.encode(
+            input_tokens=encoder_input_tokens,
+            continuous_inputs=encoder_continuous_inputs,
+            continuous_mask=encoder_continuous_mask,
+            enable_dropout=enable_dropout,
+        )
+
+        return self.decode(
+            encodings_and_masks=encodings_and_masks,
+            input_tokens=decoder_input_tokens,
+            noise_time=decoder_noise_time,
+            enable_dropout=enable_dropout,
+        )
diff --git a/src/diffusers/models/vocoders.py b/src/diffusers/models/vocoders.py
index ea2a616463ad..ac44c853b73a 100644
--- a/src/diffusers/models/vocoders.py
+++ b/src/diffusers/models/vocoders.py
@@ -3,8 +3,9 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..modeling_utils import ModelMixin
 from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+
 
 # DiffSound uses MelGAN
 class MelGAN(nn.Module):
diff --git a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
index bca897e1c570..651248b28612 100644
--- a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
+++ b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
@@ -38,7 +38,7 @@ def __init__(self, spectrogram_decoder, context_encoder, note_encoder, vocoder,
             context_encoder=context_encoder,
             note_encoder=note_encoder,
             scheduler=scheduler,
-            vocoder=vocoder
+            vocoder=vocoder,
         )

From d265f70fc5a1246473c21632379f39fbc815bff8 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 17 Sep 2022 22:03:39 +0200
Subject: [PATCH 7/7] Update, a note tokenizer will be required

---
 src/diffusers/models/music_transformer.py | 5 +++++
 .../notes2audio/pipeline_notes2audio.py    | 20 ++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/src/diffusers/models/music_transformer.py b/src/diffusers/models/music_transformer.py
index ebb258a1787e..6768bfdb5ca7 100644
--- a/src/diffusers/models/music_transformer.py
+++ b/src/diffusers/models/music_transformer.py
@@ -40,6 +40,11 @@ def __init__(self) -> None:
         super().__init__()
 
 
+class NoteTokenizer(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
 class NoteEncoder(nn.Module):
     def __init__(self) -> None:
         super().__init__()
diff --git a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
index 651248b28612..9703b9532588 100644
--- a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
+++ b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
@@ -41,6 +41,26 @@ def __init__(self, spectrogram_decoder, context_encoder, note_encoder, vocoder,
             vocoder=vocoder,
         )
 
+    def generation_step(self):
+        """
+        Generate a single frame of audio, which corresponds to 5 seconds.
+
+        Args:
+            encoder_continuous_inputs (`torch.Tensor`): fields for context
+            encoder_continuous_mask (`torch.Tensor`): fields for context
+            encoder_input_tokens (`torch.Tensor`): fields for context
+            decoder_target_tokens (`torch.Tensor`): fields for context
+            diffusion_noise (`torch.Tensor`): fields for context
+            diffusion_noise_mask (`torch.Tensor`): fields for context
+            deterministic (`bool`): fields for context
+            **kwargs (`dict`): fields for context
+
+        Returns:
+            `torch.Tensor`: The generated audio
+
+
+        """
+
     @torch.no_grad()
     def __call__(
         self,