From 61848154935c50f9f5a33907123247e01495e83e Mon Sep 17 00:00:00 2001 From: muhammad_hanif Date: Tue, 31 Jan 2023 09:02:43 +0700 Subject: [PATCH 01/17] add use_memory_efficient params placeholder --- src/diffusers/models/attention_flax.py | 26 +++++++++++++++++-- src/diffusers/models/unet_2d_blocks_flax.py | 12 +++++++++ .../models/unet_2d_condition_flax.py | 6 +++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 71106e05452c..f902c3f9af96 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -31,6 +31,8 @@ class FlaxAttentionBlock(nn.Module): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ query_dim: int @@ -38,6 +40,7 @@ class FlaxAttentionBlock(nn.Module): dim_head: int = 64 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 + use_memory_efficient: bool = False def setup(self): inner_dim = self.dim_head * self.heads @@ -108,6 +111,8 @@ class FlaxBasicTransformerBlock(nn.Module): Whether to only apply cross attention. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ dim: int n_heads: int @@ -115,12 +120,25 @@ class FlaxBasicTransformerBlock(nn.Module): dropout: float = 0.0 only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient: bool = False def setup(self): # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttentionBlock(self.dim, + self.n_heads, + self.d_head, + self.dropout, + dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, + ) # cross attention - self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttentionBlock(self.dim, + self.n_heads, + self.d_head, + self.dropout, + dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, + ) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -169,6 +187,8 @@ class FlaxTransformer2DModel(nn.Module): only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int n_heads: int @@ -178,6 +198,7 @@ class FlaxTransformer2DModel(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient: bool = False def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) @@ -202,6 +223,7 @@ def setup(self): dropout=self.dropout, only_cross_attention=self.only_cross_attention, dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) for _ in range(self.depth) ] diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 96e76cb06a59..8d3a1757073d 100644 --- 
a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -39,6 +39,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): Whether to add downsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int out_channels: int @@ -49,6 +51,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient:bool = False def setup(self): resnets = [] @@ -73,6 +76,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) attentions.append(attn_block) @@ -174,6 +178,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): Whether to add upsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int out_channels: int @@ -185,6 +191,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient:bool = False def setup(self): resnets = [] @@ -210,6 +217,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) attentions.append(attn_block) @@ -313,6 +321,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Number of attention heads of each spatial transformer block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int dropout: float = 0.0 @@ -320,6 +330,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): attn_num_head_channels: int = 1 use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 + use_memory_efficient:bool = False def setup(self): # there is always at least one resnet @@ -342,6 +353,7 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 8d8308c5bfb9..c2415a096220 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 
+ use_memory_efficient (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 """ @@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 flip_sin_to_cos: bool = True freq_shift: int = 0 + use_memory_efficient:bool = False def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors @@ -170,6 +173,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) else: down_block = FlaxDownBlock2D( @@ -191,6 +195,7 @@ def setup(self): attn_num_head_channels=attention_head_dim[-1], use_linear_projection=self.use_linear_projection, dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) # up @@ -218,6 +223,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], dtype=self.dtype, + use_memory_efficient=self.use_memory_efficient, ) else: up_block = FlaxUpBlock2D( From e2d5708c88cee7177dc3d462eab23e1dc83292c1 Mon Sep 17 00:00:00 2001 From: MuhHanif <48muhhanif@gmail.com> Date: Tue, 31 Jan 2023 09:25:38 +0700 Subject: [PATCH 02/17] test --- src/diffusers/models/attention_flax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index f902c3f9af96..4b1ee7fa7d48 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -33,6 +33,7 @@ class FlaxAttentionBlock(nn.Module): Parameters `dtype` use_memory_efficient (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 + """ query_dim: int From 9347cc580c8d3f17674371eaad026c238ec81523 Mon Sep 17 00:00:00 2001 From: MuhHanif <48muhhanif@gmail.com> Date: Fri, 3 Feb 2023 15:28:01 +0700 Subject: [PATCH 03/17] add memory efficient attention jax --- src/diffusers/models/attention_flax.py | 46 ++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 4b1ee7fa7d48..6800b1b50f02 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -14,6 +14,7 @@ import flax.linen as nn import jax.numpy as jnp +from memory_efficient_attention_jax import memory_efficient_attention class FlaxAttentionBlock(nn.Module): @@ -33,7 +34,6 @@ class FlaxAttentionBlock(nn.Module): Parameters `dtype` use_memory_efficient (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 - """ query_dim: int @@ -81,13 +81,45 @@ def __call__(self, hidden_states, context=None, deterministic=True): key_states = self.reshape_heads_to_batch_dim(key_proj) value_states = self.reshape_heads_to_batch_dim(value_proj) - # compute attentions - attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - attention_scores = attention_scores * self.scale - attention_probs = nn.softmax(attention_scores, axis=2) + if self.use_memory_efficient: + + query_states = query_states.transpose(1,0,2) + key_states =key_states.transpose(1,0,2) + value_states =value_states.transpose(1,0,2) + + #this if statement create a chunk size for each layer of the unet + #the chunk size is equal to the query_length dimension of the deepest layer of the unet + + flatten_latent_dim = 
query_states.shape[-3] + if flatten_latent_dim % 64 == 0: + query_chunk_size = int(flatten_latent_dim/64) + elif flatten_latent_dim % 16 == 0: + query_chunk_size = int(flatten_latent_dim/16) + elif flatten_latent_dim % 4 == 0: + query_chunk_size = int(flatten_latent_dim/4) + else: + query_chunk_size = int(flatten_latent_dim) + + hidden_states=memory_efficient_attention( + query_states, + key_states, + value_states, + query_chunk_size=query_chunk_size, + key_chunk_size=4096*4 + ) + + hidden_states=hidden_states.transpose(1,0,2) + + else: + + # compute attentions + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) + + # attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) - # attend to values - hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) hidden_states = self.proj_attn(hidden_states) return hidden_states From eac25e0d3548ed7d150c3a2db5e2806de9d3d395 Mon Sep 17 00:00:00 2001 From: MuhHanif <48muhhanif@gmail.com> Date: Fri, 3 Feb 2023 15:29:26 +0700 Subject: [PATCH 04/17] add memory efficient attention jax --- .../models/memory_efficient_attention_jax.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 src/diffusers/models/memory_efficient_attention_jax.py diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py new file mode 100644 index 000000000000..7c94f69d85d4 --- /dev/null +++ b/src/diffusers/models/memory_efficient_attention_jax.py @@ -0,0 +1,114 @@ +import functools, jax, math +from jax import numpy as jnp +import numpy as np + +def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): + """Multi-head dot product attention with a limited number of queries.""" + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / jnp.sqrt(k_features) + + @functools.partial(jax.checkpoint, prevent_cse=False) + def summarize_chunk(query, key, value): + attn_weights = jnp.einsum('...qhd,...khd->...qhk', query, key, precision=precision) + + max_score = jnp.max(attn_weights, axis=-1, keepdims=True) + max_score = jax.lax.stop_gradient(max_score) + exp_weights = jnp.exp(attn_weights - max_score) + + exp_values = jnp.einsum('...vhf,...qhv->...qhf', value, exp_weights, precision=precision) + max_score = jnp.einsum('...qhk->...qh', max_score) + + return (exp_values, exp_weights.sum(axis=-1), max_score) + + def chunk_scanner(chunk_idx): + #julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], #[...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features] #[...,k,h,d] + ) + + #julienne value array + value_chunk = jax.lax.dynamic_slice( + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], #[...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features] #[...,v,h,d] + ) + + return summarize_chunk(query, key_chunk, value_chunk) + + chunk_values, chunk_weights, chunk_max = jax.lax.map( + f=chunk_scanner, + xs=jnp.arange(0, num_kv, key_chunk_size) + ) + + global_max = jnp.max(chunk_max, axis=0, keepdims=True) + max_diffs = jnp.exp(chunk_max - global_max) + + chunk_values *= 
jnp.expand_dims(max_diffs, axis=-1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(axis=0) + all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) + + return all_values / all_weights + +def memory_efficient_attention( + query, + key, + value, + precision = jax.lax.Precision.HIGHEST, + query_chunk_size: int = 1024, + key_chunk_size: int = 4096): + r""" + Flax Memory-efficient multi-head dot product attention. + https://arxiv.org/abs/2112.05682v2 + https://github.com/AminRezaei0x443/memory-efficient-attention + + Args: + query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) + key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) + value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + numerical precision for computation + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array + value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array + value must divide key_value_length equally without remainder + + Returns: + (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) + """ + num_q, num_heads, q_features = query.shape[-3:] + + def chunk_scanner(chunk_idx, _): + #julienne query array + query_chunk = jax.lax.dynamic_slice( + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], #[...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features] #[...,q,h,d] + ) + + return( + chunk_idx + query_chunk_size, #unused ignore it + _query_chunk_attention( + query=query_chunk, + key=key, + value=value, + precision=precision, + key_chunk_size=4096 + ) + ) + + _, res = jax.lax.scan( + f=chunk_scanner, + init=0, #start counter + xs=None, + length=math.ceil(num_q / query_chunk_size) #stop counter + ) + + return jnp.concatenate(res, axis=-3) #fuse the chunked result back \ No newline at end of file From 99d88e632bf7a56fcd13d49b6c4dcc2f063b0638 Mon Sep 17 00:00:00 2001 From: MuhHanif <48muhhanif@gmail.com> Date: Fri, 3 Feb 2023 15:30:52 +0700 Subject: [PATCH 05/17] newline --- src/diffusers/models/memory_efficient_attention_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py index 7c94f69d85d4..793c45b1cb1a 100644 --- a/src/diffusers/models/memory_efficient_attention_jax.py +++ b/src/diffusers/models/memory_efficient_attention_jax.py @@ -111,4 +111,4 @@ def chunk_scanner(chunk_idx, _): length=math.ceil(num_q / query_chunk_size) #stop counter ) - return jnp.concatenate(res, axis=-3) #fuse the chunked result back \ No newline at end of file + return jnp.concatenate(res, axis=-3) #fuse the chunked result back From 22557946b73742e7eb04f72a71a468a96f7cf32c Mon Sep 17 00:00:00 2001 From: MuhHanif <48muhhanif@gmail.com> Date: Fri, 3 Feb 2023 15:41:51 +0700 Subject: [PATCH 06/17] forgot dot --- src/diffusers/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 6800b1b50f02..80cc10145c27 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -14,7 +14,7 @@ import flax.linen as nn import jax.numpy as jnp -from 
memory_efficient_attention_jax import memory_efficient_attention +from .memory_efficient_attention_jax import memory_efficient_attention class FlaxAttentionBlock(nn.Module): From 00803e94c627eca0704480ec363340efb49ac5f3 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:18:21 +0000 Subject: [PATCH 07/17] Rename use_memory_efficient --- src/diffusers/models/attention_flax.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 80cc10145c27..4070a593573f 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -32,7 +32,7 @@ class FlaxAttentionBlock(nn.Module): Dropout rate dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ @@ -41,7 +41,7 @@ class FlaxAttentionBlock(nn.Module): dim_head: int = 64 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def setup(self): inner_dim = self.dim_head * self.heads @@ -81,8 +81,7 @@ def __call__(self, hidden_states, context=None, deterministic=True): key_states = self.reshape_heads_to_batch_dim(key_proj) value_states = self.reshape_heads_to_batch_dim(value_proj) - if self.use_memory_efficient: - + if self.use_memory_efficient_attention: query_states = query_states.transpose(1,0,2) key_states =key_states.transpose(1,0,2) value_states =value_states.transpose(1,0,2) @@ -109,9 +108,7 @@ def __call__(self, hidden_states, context=None, deterministic=True): ) hidden_states=hidden_states.transpose(1,0,2) - else: - # compute attentions attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) attention_scores = attention_scores * self.scale @@ -144,7 +141,7 @@ class FlaxBasicTransformerBlock(nn.Module): Whether to only apply cross attention. 
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ dim: int @@ -153,7 +150,7 @@ class FlaxBasicTransformerBlock(nn.Module): dropout: float = 0.0 only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def setup(self): # self attention (or cross_attention if only_cross_attention is True) @@ -162,7 +159,7 @@ def setup(self): self.d_head, self.dropout, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) # cross attention self.attn2 = FlaxAttentionBlock(self.dim, @@ -170,7 +167,7 @@ def setup(self): self.d_head, self.dropout, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -220,7 +217,7 @@ class FlaxTransformer2DModel(nn.Module): only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int @@ -231,7 +228,7 @@ class FlaxTransformer2DModel(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) @@ -256,7 +253,7 @@ def setup(self): dropout=self.dropout, only_cross_attention=self.only_cross_attention, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) for _ in range(self.depth) ] From 26ba0c4434a5c9376800436590413407df4c9407 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:24:34 +0000 Subject: [PATCH 08/17] Keep dtype last. 
--- src/diffusers/models/attention_flax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index e0098bb1110a..3c3f32995a9e 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -30,18 +30,18 @@ class FlaxAttention(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ query_dim: int heads: int = 8 dim_head: int = 64 dropout: float = 0.0 - dtype: jnp.dtype = jnp.float32 use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim_head * self.heads @@ -154,9 +154,9 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention) + self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype) # cross attention - self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention) + self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) From 4b82e44ecf7ca74bb5515929002c2995fe162bad Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:28:55 +0000 Subject: [PATCH 09/17] Actually use key_chunk_size --- .../models/memory_efficient_attention_jax.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py index 793c45b1cb1a..a301b5a50552 100644 --- a/src/diffusers/models/memory_efficient_attention_jax.py +++ b/src/diffusers/models/memory_efficient_attention_jax.py @@ -1,6 +1,5 @@ import functools, jax, math from jax import numpy as jnp -import numpy as np def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" @@ -23,14 +22,14 @@ def summarize_chunk(query, key, value): return (exp_values, exp_weights.sum(axis=-1), max_score) def chunk_scanner(chunk_idx): - #julienne key array + # julienne key array key_chunk = jax.lax.dynamic_slice( operand=key, start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], #[...,k,h,d] slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features] #[...,k,h,d] ) - #julienne value array + # julienne value array value_chunk = jax.lax.dynamic_slice( operand=value, start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], #[...,v,h,d] @@ -86,7 +85,7 @@ def memory_efficient_attention( num_q, num_heads, q_features = query.shape[-3:] def chunk_scanner(chunk_idx, _): - 
#julienne query array + # julienne query array query_chunk = jax.lax.dynamic_slice( operand=query, start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], #[...,q,h,d] @@ -94,21 +93,21 @@ def chunk_scanner(chunk_idx, _): ) return( - chunk_idx + query_chunk_size, #unused ignore it + chunk_idx + query_chunk_size, # unused ignore it _query_chunk_attention( query=query_chunk, key=key, value=value, precision=precision, - key_chunk_size=4096 + key_chunk_size=key_chunk_size ) ) _, res = jax.lax.scan( f=chunk_scanner, - init=0, #start counter + init=0, # start counter xs=None, - length=math.ceil(num_q / query_chunk_size) #stop counter + length=math.ceil(num_q / query_chunk_size) # stop counter ) - return jnp.concatenate(res, axis=-3) #fuse the chunked result back + return jnp.concatenate(res, axis=-3) # fuse the chunked result back From 78a106e4fc3be0406765fd72d4a1dd3df45bd4ab Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:31:54 +0000 Subject: [PATCH 10/17] Rename symbol --- src/diffusers/models/attention_flax.py | 4 ++-- src/diffusers/models/memory_efficient_attention_jax.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 3c3f32995a9e..7a4da1281405 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -14,7 +14,7 @@ import flax.linen as nn import jax.numpy as jnp -from .memory_efficient_attention_jax import memory_efficient_attention +from .memory_efficient_attention_jax import jax_memory_efficient_attention class FlaxAttention(nn.Module): @@ -99,7 +99,7 @@ def __call__(self, hidden_states, context=None, deterministic=True): else: query_chunk_size = int(flatten_latent_dim) - hidden_states=memory_efficient_attention( + hidden_states=jax_memory_efficient_attention( query_states, key_states, value_states, diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py index a301b5a50552..f43f43ae791a 100644 --- a/src/diffusers/models/memory_efficient_attention_jax.py +++ b/src/diffusers/models/memory_efficient_attention_jax.py @@ -54,7 +54,7 @@ def chunk_scanner(chunk_idx): return all_values / all_weights -def memory_efficient_attention( +def jax_memory_efficient_attention( query, key, value, From 41ea7c28f9902ecbcb105b1cd44bf195b2d02ed6 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:34:56 +0000 Subject: [PATCH 11/17] Apply style --- src/diffusers/models/attention_flax.py | 43 +++++---- .../models/memory_efficient_attention_jax.py | 96 ++++++++----------- src/diffusers/models/unet_2d_blocks_flax.py | 6 +- .../models/unet_2d_condition_flax.py | 2 +- 4 files changed, 68 insertions(+), 79 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 7a4da1281405..33c0b8840a30 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -14,6 +14,7 @@ import flax.linen as nn import jax.numpy as jnp + from .memory_efficient_attention_jax import jax_memory_efficient_attention @@ -82,32 +83,28 @@ def __call__(self, hidden_states, context=None, deterministic=True): value_states = self.reshape_heads_to_batch_dim(value_proj) if self.use_memory_efficient_attention: - query_states = query_states.transpose(1,0,2) - key_states =key_states.transpose(1,0,2) - value_states =value_states.transpose(1,0,2) - - #this if statement create a chunk size for each layer of the 
unet - #the chunk size is equal to the query_length dimension of the deepest layer of the unet - + query_states = query_states.transpose(1, 0, 2) + key_states = key_states.transpose(1, 0, 2) + value_states = value_states.transpose(1, 0, 2) + + # this if statement create a chunk size for each layer of the unet + # the chunk size is equal to the query_length dimension of the deepest layer of the unet + flatten_latent_dim = query_states.shape[-3] if flatten_latent_dim % 64 == 0: - query_chunk_size = int(flatten_latent_dim/64) + query_chunk_size = int(flatten_latent_dim / 64) elif flatten_latent_dim % 16 == 0: - query_chunk_size = int(flatten_latent_dim/16) + query_chunk_size = int(flatten_latent_dim / 16) elif flatten_latent_dim % 4 == 0: - query_chunk_size = int(flatten_latent_dim/4) + query_chunk_size = int(flatten_latent_dim / 4) else: query_chunk_size = int(flatten_latent_dim) - - hidden_states=jax_memory_efficient_attention( - query_states, - key_states, - value_states, - query_chunk_size=query_chunk_size, - key_chunk_size=4096*4 + + hidden_states = jax_memory_efficient_attention( + query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 ) - - hidden_states=hidden_states.transpose(1,0,2) + + hidden_states = hidden_states.transpose(1, 0, 2) else: # compute attentions attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) @@ -154,9 +151,13 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype) + self.attn1 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) # cross attention - self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype) + self.attn2 = FlaxAttention( + self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype + ) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py index f43f43ae791a..cd156c14e1c3 100644 --- a/src/diffusers/models/memory_efficient_attention_jax.py +++ b/src/diffusers/models/memory_efficient_attention_jax.py @@ -1,6 +1,10 @@ -import functools, jax, math +import functools +import math + +import jax from jax import numpy as jnp + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] @@ -10,74 +14,65 @@ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4 @functools.partial(jax.checkpoint, prevent_cse=False) def summarize_chunk(query, key, value): - attn_weights = jnp.einsum('...qhd,...khd->...qhk', query, key, precision=precision) + attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) max_score = jnp.max(attn_weights, axis=-1, keepdims=True) max_score = jax.lax.stop_gradient(max_score) exp_weights = jnp.exp(attn_weights - max_score) - exp_values = jnp.einsum('...vhf,...qhv->...qhf', value, exp_weights, precision=precision) 
- max_score = jnp.einsum('...qhk->...qh', max_score) + exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) + max_score = jnp.einsum("...qhk->...qh", max_score) return (exp_values, exp_weights.sum(axis=-1), max_score) def chunk_scanner(chunk_idx): - # julienne key array - key_chunk = jax.lax.dynamic_slice( - operand=key, - start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], #[...,k,h,d] - slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features] #[...,k,h,d] + # julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] ) - # julienne value array + # julienne value array value_chunk = jax.lax.dynamic_slice( - operand=value, - start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], #[...,v,h,d] - slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features] #[...,v,h,d] + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] ) return summarize_chunk(query, key_chunk, value_chunk) - chunk_values, chunk_weights, chunk_max = jax.lax.map( - f=chunk_scanner, - xs=jnp.arange(0, num_kv, key_chunk_size) - ) - + chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) + global_max = jnp.max(chunk_max, axis=0, keepdims=True) max_diffs = jnp.exp(chunk_max - global_max) chunk_values *= jnp.expand_dims(max_diffs, axis=-1) chunk_weights *= max_diffs - + all_values = chunk_values.sum(axis=0) all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) return all_values / all_weights + def jax_memory_efficient_attention( - query, - key, - value, - precision = jax.lax.Precision.HIGHEST, - query_chunk_size: int = 1024, - key_chunk_size: int = 4096): + query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 +): r""" - Flax Memory-efficient multi-head dot product attention. - https://arxiv.org/abs/2112.05682v2 + Flax Memory-efficient multi-head dot product attention. 
https://arxiv.org/abs/2112.05682v2 https://github.com/AminRezaei0x443/memory-efficient-attention - + Args: query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) - precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): numerical precision for computation - query_chunk_size (`int`, *optional*, defaults to 1024): - chunk size to divide query array - value must divide query_length equally without remainder - key_chunk_size (`int`, *optional*, defaults to 4096): - chunk size to divide key and value array - value must divide key_value_length equally without remainder + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array value must divide key_value_length equally without remainder Returns: (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) @@ -85,29 +80,22 @@ def jax_memory_efficient_attention( num_q, num_heads, q_features = query.shape[-3:] def chunk_scanner(chunk_idx, _): - # julienne query array + # julienne query array query_chunk = jax.lax.dynamic_slice( - operand=query, - start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], #[...,q,h,d] - slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features] #[...,q,h,d] + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] ) - return( - chunk_idx + query_chunk_size, # unused ignore it - _query_chunk_attention( - query=query_chunk, - key=key, - value=value, - precision=precision, - key_chunk_size=key_chunk_size - ) + return ( + chunk_idx + query_chunk_size, # unused ignore it + _query_chunk_attention( + query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size + ), ) - _, res = jax.lax.scan( - f=chunk_scanner, - init=0, # start counter - xs=None, - length=math.ceil(num_q / query_chunk_size) # stop counter + _, res = jax.lax.scan( + f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter ) - return jnp.concatenate(res, axis=-3) # fuse the chunked result back + return jnp.concatenate(res, axis=-3) # fuse the chunked result back diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 735728061d40..c6252a7005f9 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -51,7 +51,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient:bool = False + use_memory_efficient: bool = False def setup(self): resnets = [] @@ -191,7 +191,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient:bool = False + use_memory_efficient: bool = False def setup(self): resnets = [] @@ -330,7 +330,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): 
attn_num_head_channels: int = 1 use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient:bool = False + use_memory_efficient: bool = False def setup(self): # there is always at least one resnet diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index dbcfd3ff1340..ff19a9dadbf6 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -113,7 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 flip_sin_to_cos: bool = True freq_shift: int = 0 - use_memory_efficient:bool = False + use_memory_efficient: bool = False def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors From f32f33131eac9f52dbf599a676028d9357bcaef2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:53:03 +0000 Subject: [PATCH 12/17] Rename use_memory_efficient --- src/diffusers/models/unet_2d_blocks_flax.py | 20 +++++++++---------- .../models/unet_2d_condition_flax.py | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index c6252a7005f9..0e6ef3237d7d 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -37,10 +37,10 @@ class FlaxCrossAttnDownBlock2D(nn.Module): Number of attention heads of each spatial transformer block add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): - enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int out_channels: int @@ -50,8 +50,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): add_downsample: bool = True use_linear_projection: bool = False only_cross_attention: bool = False + use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False def setup(self): resnets = [] @@ -76,7 +76,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) attentions.append(attn_block) @@ -178,7 +178,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): Whether to add upsampling layer before each final output dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int @@ -191,7 +191,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def setup(self): resnets = [] @@ -217,7 +217,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, dtype=self.dtype, - 
use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) attentions.append(attn_block) @@ -321,7 +321,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Number of attention heads of each spatial transformer block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int @@ -330,7 +330,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): attn_num_head_channels: int = 1 use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def setup(self): # there is always at least one resnet @@ -353,7 +353,7 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index ff19a9dadbf6..c8c0df6f144b 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -88,7 +88,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - use_memory_efficient (`bool`, *optional*, defaults to `False`): + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ @@ -113,7 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dtype: jnp.dtype = jnp.float32 flip_sin_to_cos: bool = True freq_shift: int = 0 - use_memory_efficient: bool = False + use_memory_efficient_attention: bool = False def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors @@ -173,7 +173,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) else: down_block = FlaxDownBlock2D( @@ -195,7 +195,7 @@ def setup(self): attn_num_head_channels=attention_head_dim[-1], use_linear_projection=self.use_linear_projection, dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) # up @@ -223,7 +223,7 @@ def setup(self): use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], dtype=self.dtype, - use_memory_efficient=self.use_memory_efficient, + use_memory_efficient_attention=self.use_memory_efficient_attention, ) else: up_block = FlaxUpBlock2D( From 00be593f7327525055b5ac21abdb821c12a511d7 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 10:55:45 +0000 Subject: [PATCH 13/17] Keep dtype last --- src/diffusers/models/unet_2d_blocks_flax.py | 18 +++++++++--------- src/diffusers/models/unet_2d_condition_flax.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git 
a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 0e6ef3237d7d..b8126c5f5930 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -75,8 +75,8 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, - dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) attentions.append(attn_block) @@ -176,10 +176,10 @@ class FlaxCrossAttnUpBlock2D(nn.Module): Number of attention heads of each spatial transformer block add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int out_channels: int @@ -190,8 +190,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): add_upsample: bool = True use_linear_projection: bool = False only_cross_attention: bool = False - dtype: jnp.dtype = jnp.float32 use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] @@ -216,8 +216,8 @@ def setup(self): depth=1, use_linear_projection=self.use_linear_projection, only_cross_attention=self.only_cross_attention, - dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) attentions.append(attn_block) @@ -319,18 +319,18 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): Number of attention blocks layers attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): Number of attention heads of each spatial transformer block - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` """ in_channels: int dropout: float = 0.0 num_layers: int = 1 attn_num_head_channels: int = 1 use_linear_projection: bool = False - dtype: jnp.dtype = jnp.float32 use_memory_efficient_attention: bool = False + dtype: jnp.dtype = jnp.float32 def setup(self): # there is always at least one resnet @@ -352,8 +352,8 @@ def setup(self): d_head=self.in_channels // self.attn_num_head_channels, depth=1, use_linear_projection=self.use_linear_projection, - dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index c8c0df6f144b..3c2f4a88ab7f 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -172,8 +172,8 @@ def setup(self): add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], - dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) else: down_block = FlaxDownBlock2D( @@ -194,8 +194,8 @@ def setup(self): dropout=self.dropout, attn_num_head_channels=attention_head_dim[-1], use_linear_projection=self.use_linear_projection, - 
dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) # up @@ -222,8 +222,8 @@ def setup(self): dropout=self.dropout, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], - dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, + dtype=self.dtype, ) else: up_block = FlaxUpBlock2D( From b5544f23ba1678fa4e39e837b6df086ab9a66443 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 29 Mar 2023 12:29:18 +0000 Subject: [PATCH 14/17] Pass `use_memory_efficient_attention` in `from_pretrained` --- src/diffusers/pipelines/pipeline_flax_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 9d91ff757799..6ab0b80ee655 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pt = kwargs.pop("from_pt", False) + use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) dtype = kwargs.pop("dtype", None) # 1. Download the checkpoints and configs @@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): - loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) + loaded_sub_model, loaded_params = load_method( + loadable_folder, + from_pt=from_pt, + use_memory_efficient_attention=use_memory_efficient_attention, + dtype=dtype, + ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): if from_pt: From 8d35f092a9cf841fd7343adc5e97aa126850cfce Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 10 Apr 2023 19:21:30 +0000 Subject: [PATCH 15/17] Move JAX memory efficient attention to attention_flax. --- src/diffusers/models/attention_flax.py | 101 +++++++++++++++++- .../models/memory_efficient_attention_jax.py | 101 ------------------ 2 files changed, 99 insertions(+), 103 deletions(-) delete mode 100644 src/diffusers/models/memory_efficient_attention_jax.py diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 33c0b8840a30..e900e39724fc 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -12,10 +12,107 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import flax.linen as nn +import functools +import math + +import jax import jax.numpy as jnp +import flax.linen as nn + +def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): + """Multi-head dot product attention with a limited number of queries.""" + num_kv, num_heads, k_features = key.shape[-3:] + v_features = value.shape[-1] + key_chunk_size = min(key_chunk_size, num_kv) + query = query / jnp.sqrt(k_features) + + @functools.partial(jax.checkpoint, prevent_cse=False) + def summarize_chunk(query, key, value): + attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) + + max_score = jnp.max(attn_weights, axis=-1, keepdims=True) + max_score = jax.lax.stop_gradient(max_score) + exp_weights = jnp.exp(attn_weights - max_score) + + exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) + max_score = jnp.einsum("...qhk->...qh", max_score) + + return (exp_values, exp_weights.sum(axis=-1), max_score) + + def chunk_scanner(chunk_idx): + # julienne key array + key_chunk = jax.lax.dynamic_slice( + operand=key, + start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] + slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] + ) + + # julienne value array + value_chunk = jax.lax.dynamic_slice( + operand=value, + start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] + slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] + ) + + return summarize_chunk(query, key_chunk, value_chunk) + + chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) + + global_max = jnp.max(chunk_max, axis=0, keepdims=True) + max_diffs = jnp.exp(chunk_max - global_max) + + chunk_values *= jnp.expand_dims(max_diffs, axis=-1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(axis=0) + all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) + + return all_values / all_weights + + +def jax_memory_efficient_attention( + query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 +): + r""" + Flax Memory-efficient multi-head dot product attention. 
https://arxiv.org/abs/2112.05682v2 + https://github.com/AminRezaei0x443/memory-efficient-attention + + Args: + query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) + key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) + value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) + precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): + numerical precision for computation + query_chunk_size (`int`, *optional*, defaults to 1024): + chunk size to divide query array value must divide query_length equally without remainder + key_chunk_size (`int`, *optional*, defaults to 4096): + chunk size to divide key and value array value must divide key_value_length equally without remainder + + Returns: + (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) + """ + num_q, num_heads, q_features = query.shape[-3:] + + def chunk_scanner(chunk_idx, _): + # julienne query array + query_chunk = jax.lax.dynamic_slice( + operand=query, + start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] + slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] + ) + + return ( + chunk_idx + query_chunk_size, # unused ignore it + _query_chunk_attention( + query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size + ), + ) + + _, res = jax.lax.scan( + f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter + ) -from .memory_efficient_attention_jax import jax_memory_efficient_attention + return jnp.concatenate(res, axis=-3) # fuse the chunked result back class FlaxAttention(nn.Module): diff --git a/src/diffusers/models/memory_efficient_attention_jax.py b/src/diffusers/models/memory_efficient_attention_jax.py deleted file mode 100644 index cd156c14e1c3..000000000000 --- a/src/diffusers/models/memory_efficient_attention_jax.py +++ /dev/null @@ -1,101 +0,0 @@ -import functools -import math - -import jax -from jax import numpy as jnp - - -def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): - """Multi-head dot product attention with a limited number of queries.""" - num_kv, num_heads, k_features = key.shape[-3:] - v_features = value.shape[-1] - key_chunk_size = min(key_chunk_size, num_kv) - query = query / jnp.sqrt(k_features) - - @functools.partial(jax.checkpoint, prevent_cse=False) - def summarize_chunk(query, key, value): - attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) - - max_score = jnp.max(attn_weights, axis=-1, keepdims=True) - max_score = jax.lax.stop_gradient(max_score) - exp_weights = jnp.exp(attn_weights - max_score) - - exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) - max_score = jnp.einsum("...qhk->...qh", max_score) - - return (exp_values, exp_weights.sum(axis=-1), max_score) - - def chunk_scanner(chunk_idx): - # julienne key array - key_chunk = jax.lax.dynamic_slice( - operand=key, - start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] - slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] - ) - - # julienne value array - value_chunk = jax.lax.dynamic_slice( - operand=value, - start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] - slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] - ) - - return 
summarize_chunk(query, key_chunk, value_chunk) - - chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) - - global_max = jnp.max(chunk_max, axis=0, keepdims=True) - max_diffs = jnp.exp(chunk_max - global_max) - - chunk_values *= jnp.expand_dims(max_diffs, axis=-1) - chunk_weights *= max_diffs - - all_values = chunk_values.sum(axis=0) - all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) - - return all_values / all_weights - - -def jax_memory_efficient_attention( - query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 -): - r""" - Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 - https://github.com/AminRezaei0x443/memory-efficient-attention - - Args: - query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) - key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) - value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) - precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): - numerical precision for computation - query_chunk_size (`int`, *optional*, defaults to 1024): - chunk size to divide query array value must divide query_length equally without remainder - key_chunk_size (`int`, *optional*, defaults to 4096): - chunk size to divide key and value array value must divide key_value_length equally without remainder - - Returns: - (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) - """ - num_q, num_heads, q_features = query.shape[-3:] - - def chunk_scanner(chunk_idx, _): - # julienne query array - query_chunk = jax.lax.dynamic_slice( - operand=query, - start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] - slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] - ) - - return ( - chunk_idx + query_chunk_size, # unused ignore it - _query_chunk_attention( - query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size - ), - ) - - _, res = jax.lax.scan( - f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter - ) - - return jnp.concatenate(res, axis=-3) # fuse the chunked result back From 952293c11accbdbf2924dfd121723b18765cf453 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 11:22:29 +0000 Subject: [PATCH 16/17] Simple test. 
--- tests/test_pipelines_flax.py | 44 ++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index a461930f3a83..d4de65948e36 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -224,3 +224,47 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self): if jax.device_count() == 8: assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3 assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1 + + def test_jax_memory_efficient_attention(self): + prompt = ( + "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" + " field, close up, split lighting, cinematic" + ) + + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples) + + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="bf16", + dtype=jnp.bfloat16, + safety_checker=None, + ) + + params = replicate(params) + prompt_ids = pipeline.prepare_inputs(prompt) + prompt_ids = shard(prompt_ids) + images = pipeline(prompt_ids, params, prng_seed, jit=True).images + assert images.shape == (num_samples, 1, 512, 512, 3) + slice = images[2, 0, 256, 10:17, 1] + + # With memory efficient attention + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="bf16", + dtype=jnp.bfloat16, + safety_checker=None, + use_memory_efficient_attention=True, + ) + + params = replicate(params) + prompt_ids = pipeline.prepare_inputs(prompt) + prompt_ids = shard(prompt_ids) + images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images + assert images_eff.shape == (num_samples, 1, 512, 512, 3) + slice_eff = images[2, 0, 256, 10:17, 1] + + # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum` + # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now. + assert abs(slice_eff - slice).max() < 1e-2 \ No newline at end of file From d1490bcd12b7e90a28d5bcad19d5dc8eff4b93f7 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 11:23:01 +0000 Subject: [PATCH 17/17] style --- src/diffusers/models/attention_flax.py | 3 ++- tests/test_pipelines_flax.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index e900e39724fc..4f78b324a8e2 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -15,9 +15,10 @@ import functools import math +import flax.linen as nn import jax import jax.numpy as jnp -import flax.linen as nn + def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index d4de65948e36..33f3aa671b3e 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -267,4 +267,4 @@ def test_jax_memory_efficient_attention(self): # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum` # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now. 
-        assert abs(slice_eff - slice).max() < 1e-2
\ No newline at end of file
+        assert abs(slice_eff - slice).max() < 1e-2
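
The new flag is exercised end to end by the test added in PATCH 16. For quick reference, the following is a minimal sketch of turning it on through `from_pretrained` (the plumbing added in PATCH 14); it mirrors the test code, and the checkpoint, prompt text, and multi-device assumptions are illustrative rather than prescriptive:

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    from diffusers import FlaxStableDiffusionPipeline

    # `use_memory_efficient_attention` is popped in from_pretrained (PATCH 14) and
    # forwarded to every Flax sub-model, which passes it down to FlaxAttention.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="bf16",
        dtype=jnp.bfloat16,
        safety_checker=None,
        use_memory_efficient_attention=True,
    )

    # One prompt per device; the prompt text is an arbitrary example.
    prompt = ["a photograph of an astronaut riding a horse"] * jax.device_count()
    prng_seed = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

    params = replicate(params)
    prompt_ids = shard(pipeline.prepare_inputs(prompt))

    # Same call pattern as the test; output shape is (devices, 1, 512, 512, 3).
    images = pipeline(prompt_ids, params, prng_seed, jit=True).images

No further tuning is exposed: FlaxAttention derives `query_chunk_size` from the flattened latent dimension of each UNet layer (dividing by 64, 16, or 4 when possible) and uses a fixed `key_chunk_size` of 4096 * 4, as implemented in the attention changes above.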