From 8de409926c15eba5e94518ec64634920b94cbab5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 14:08:48 +0000 Subject: [PATCH 1/5] Improve memory text to video --- src/diffusers/models/attention.py | 21 ++++++++++- src/diffusers/models/attention_processor.py | 1 + src/diffusers/models/unet_3d_condition.py | 37 +++++++++++++++++++ .../pipeline_text_to_video_synth.py | 3 ++ .../pipeline_text_to_video_synth_img2img.py | 3 ++ 5 files changed, 63 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8805257ebe9a..e6283d9d5108 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -119,6 +119,15 @@ def __init__( self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + # let chunk size default to None + self._chunk_size = None + self._chuck_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + def forward( self, hidden_states: torch.FloatTensor, @@ -126,7 +135,6 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. @@ -141,6 +149,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, @@ -171,7 +180,15 @@ def forward( if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - ff_output = self.ff(norm_hidden_states) + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. 
Make sure to change `self._chunk_size` accordingly by calling `unet.enable_forward_chunking`.") + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat([self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim) + else: + ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0bc7886c2653..af724b4cf600 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -161,6 +161,7 @@ def __init__( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) + def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 36dcaf21f827..2a642c5f4782 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -389,6 +389,43 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + def enable_forward_chunking(self, chunk_size=None, dim=0): + """ + Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, dim) + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index e30f183808a5..680a524732e9 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -634,6 +634,9 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Chunk feed-forward computation to save memory + self.unet.enable_forward_chunking(chunk_size=1, dim=1) + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index ce5109a58213..1b6cd9c2b392 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -709,6 +709,9 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Chunk feed-forward computation to save memory + self.unet.enable_forward_chunking(chunk_size=1, dim=1) + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: From 67fad2781c83bbf4f8cc74bdd6e294f302f72e0a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 16:10:01 +0200 Subject: [PATCH 2/5] Apply suggestions from code review --- src/diffusers/models/attention.py | 1 + src/diffusers/models/attention_processor.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e6283d9d5108..8dc0319dea20 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -135,6 +135,7 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. 
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index af724b4cf600..0bc7886c2653 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -161,7 +161,6 @@ def __init__( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) - def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None From 2f3a9a66ae48b5bb23772acd029780e0ce4fc069 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 14:18:19 +0000 Subject: [PATCH 3/5] add test --- src/diffusers/models/attention.py | 9 +++++++-- src/diffusers/models/unet_3d_condition.py | 11 +++++++---- tests/models/test_models_unet_3d_condition.py | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8dc0319dea20..6c4d7f71868b 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -184,10 +184,15 @@ def forward( if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to change `self._chunk_size` accordingly by calling `unet.enable_forward_chunking`.") + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to change `self._chunk_size` accordingly by calling `unet.enable_forward_chunking`." + ) num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat([self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim) + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) else: ff_output = self.ff(norm_hidden_states) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 2a642c5f4782..9bc89c571c52 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -391,13 +391,16 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): def enable_forward_chunking(self, chunk_size=None, dim=0): """ - Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). + The dimension over which the feed-forward computation should be chunked. 
Choose between dim=0 (batch) + or dim=1 (sequence length). """ if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") @@ -424,7 +427,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): - fn_recursive_feed_forward(module, None, dim) + fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 3f29d0a41e18..72a33854bdcd 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -399,5 +399,23 @@ def test_lora_xformers_on_off(self): assert (sample - on_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4 + def test_feed_forward_chunking(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict["norm_num_groups"] = 32 + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict)[0] + + model.enable_forward_chunking() + with torch.no_grad(): + output_2 = model(**inputs_dict)[0] + + self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") + assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + # (todo: sayakpaul) implement SLOW tests. From 3916c4dfadafb96913520cb2433d98a8a27a10f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 16:58:48 +0200 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6c4d7f71868b..51212e6147b0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -185,7 +185,7 @@ def forward( # "feed_forward_chunk_size" can be used to save memory if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to change `self._chunk_size` accordingly by calling `unet.enable_forward_chunking`." + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size From f418e06c3615312f47eb4f64e5831c37b0e82119 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Jul 2023 15:59:24 +0000 Subject: [PATCH 5/5] finish test setup --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6c4d7f71868b..c3e99af5c365 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -121,7 +121,7 @@ def __init__( # let chunk size default to None self._chunk_size = None - self._chuck_dim = 0 + self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): # Sets chunk feed-forward
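
The memory saving in this series comes from running the position-wise feed-forward network over slices of `hidden_states` and concatenating the results along the chunked dimension, instead of materializing the full intermediate activation in one pass. Because the feed-forward acts on every position independently, the chunked and unchunked outputs agree. Below is a minimal standalone sketch of that equivalence, with a plain PyTorch MLP standing in for diffusers' `FeedForward`; the shapes and the module are illustrative, not taken from the patch.

import torch
import torch.nn as nn

# Illustrative stand-in for the position-wise FeedForward used in BasicTransformerBlock.
ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

hidden_states = torch.randn(2, 16, 64)  # (batch, sequence length, channels)

chunk_dim, chunk_size = 1, 1
num_chunks = hidden_states.shape[chunk_dim] // chunk_size

# Chunked pass: apply the feed-forward slice by slice and concatenate.
# Peak activation memory scales with chunk_size instead of the full sequence length.
ff_chunked = torch.cat(
    [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)

# Unchunked pass for comparison; both paths produce the same result.
ff_full = ff(hidden_states)
assert torch.allclose(ff_chunked, ff_full, atol=1e-6)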
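
A caller-side usage sketch of the new UNet methods follows; the checkpoint name, dtype, and prompt are assumptions chosen for illustration. Note that the two text-to-video pipelines patched above already call `enable_forward_chunking(chunk_size=1, dim=1)` internally, so an explicit call mainly matters when driving the UNet directly or when a different `chunk_size` is wanted.

import torch
from diffusers import DiffusionPipeline

# Assumed text-to-video checkpoint; any pipeline built around UNet3DConditionModel behaves the same way.
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

# Chunk the transformer feed-forward layers over dim=1 (sequence length) to lower peak memory.
pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)

video_frames = pipe("an astronaut riding a horse", num_frames=16).frames

# Revert to the unchunked feed-forward when memory is not a constraint.
pipe.unet.disable_forward_chunking()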