From a467aa506ac81914da6be8d255a3de299f36ccea Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 16 Dec 2022 19:37:09 +0000 Subject: [PATCH 01/10] Add missing scale Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7f652961..7aabfbba 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -138,7 +138,7 @@ def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attention_probs = attention_scores.softmax(dim=-1) # compute attention output hidden_states = torch.matmul(attention_probs, value) From 6d0852b43dbbaf9bcf047b624c1a84a6b796b759 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 16 Dec 2022 21:34:58 +0000 Subject: [PATCH 02/10] [WIP] Add efficient attention to DiffusionModelUNet Signed-off-by: Walter Hugo Lopez Pinaya --- .../networks/nets/diffusion_model_unet.py | 47 +++++++++++++------ requirements.txt | 1 + 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7aabfbba..869163d6 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -34,10 +34,17 @@ import torch import torch.nn.functional as F +import xformers.ops from monai.networks.blocks import Convolution from monai.networks.layers.factories import Pool from torch import nn +has_xformers = True + +# TODO: Make optional import work +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + __all__ = ["DiffusionModelUNet"] @@ -114,8 +121,8 @@ def __init__( inner_dim = num_head_channels * num_attention_heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.scale = num_head_channels**-0.5 - self.heads = num_attention_heads + self.scale = math.sqrt(num_head_channels) + self.num_heads = num_attention_heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) @@ -125,27 +132,30 @@ def __init__( def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape - head_size = self.heads - x = x.reshape(batch_size, seq_len, head_size, dim // head_size) - x = x.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape - head_size = self.heads - x = x.reshape(batch_size // head_size, head_size, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + 
def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attention_probs = attention_scores.softmax(dim=-1) - # compute attention output - hidden_states = torch.matmul(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states + x = torch.matmul(attention_probs, value) + return x def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: query = self.to_q(x) @@ -153,11 +163,18 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> to key = self.to_k(context) value = self.to_v(context) + # Multi-Head Attention query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) - x = self._attention(query, key, value) + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) return self.to_out(x) diff --git a/requirements.txt b/requirements.txt index 88b7af44..6561ae05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ monai-weekly==1.1.dev2248 numpy>=1.17 torch>=1.8 tqdm +xformers==0.0.16 From b7eb889c9a5dd3908ee8f28a1dfd19b822322cef Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 16 Dec 2022 21:59:09 +0000 Subject: [PATCH 03/10] [WIP] Add baddbmm-based attention to DiffusionModelUNet Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 869163d6..e5decea3 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -39,7 +39,7 @@ from monai.networks.layers.factories import Pool from torch import nn -has_xformers = True +has_xformers = False # TODO: Make optional import work # from monai.utils import optional_import @@ -152,9 +152,15 @@ def _memory_efficient_attention_xformers( return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attention_probs = attention_scores.softmax(dim=-1) - x = torch.matmul(attention_probs, value) + x = torch.bmm(attention_probs, value) return x def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: From e289ec93602386190e627181b22aad1b5d8e1081 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 16 Dec 2022 22:45:36 +0000 Subject: [PATCH 04/10] Add efficient attentions to AutoencoderKL Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 120 ++++++++++++++++------ 1 file 
changed, 88 insertions(+), 32 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index f1a75636..ad028194 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -9,13 +9,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F +import xformers.ops from monai.networks.blocks import Convolution +has_xformers = True + +# TODO: Make optional import work +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + __all__ = ["AutoencoderKL"] @@ -177,8 +185,11 @@ def __init__( self.spatial_dims = spatial_dims self.in_channels = in_channels + self.num_heads = 1 + self.scale = 1 / math.sqrt(in_channels / self.num_heads) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.q = Convolution( + self.to_q = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -187,7 +198,7 @@ def __init__( padding=0, conv_only=True, ) - self.k = Convolution( + self.to_k = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -196,7 +207,7 @@ def __init__( padding=0, conv_only=True, ) - self.v = Convolution( + self.to_v = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -215,45 +226,90 @@ def __init__( conv_only=True, ) + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + def forward(self, x: torch.Tensor) -> torch.Tensor: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # Compute attention - b = q.shape[0] - c = q.shape[1] - h = q.shape[2] - w = q.shape[3] - # in order to Torchscript work, we initialise d = 1 - d = 1 + residual = x + + # Norm + x = self.norm(x) + + # Project to query, key, and value vectors + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Reshape to B, C, SPATIAL_DIMS + batch = query.shape[0] + channel = query.shape[1] + height = query.shape[2] + width = query.shape[3] + # Note: 
in order to make Torchscript's tests work, we initialise depth = 1 + depth = 1 if self.spatial_dims == 3: - d = q.shape[4] - n_spatial_elements = h * w * d + depth = query.shape[4] - q = q.reshape(b, c, n_spatial_elements) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, n_spatial_elements) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c) ** (-0.5)) - w_ = F.softmax(w_, dim=2) + n_spatial_elements = height * width * depth + + query = query.reshape(batch, channel, n_spatial_elements) + key = key.reshape(batch, channel, n_spatial_elements) + value = value.reshape(batch, channel, n_spatial_elements) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) - # Attend to values - v = v.reshape(b, c, n_spatial_elements) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + # Reshape to B, C, H, W [, D] if self.spatial_dims == 2: - h_ = h_.reshape(b, c, h, w) + x = x.reshape(batch, channel, height, width) if self.spatial_dims == 3: - h_ = h_.reshape(b, c, h, w, d) + x = x.reshape(batch, channel, height, width, depth) - h_ = self.proj_out(h_) + # Proj out + x = self.proj_out(x) - return x + h_ + return x + residual class Encoder(nn.Module): From f6ffa40510e143760a2bd3f3724fb67632b815d6 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 23 Dec 2022 15:18:05 +0000 Subject: [PATCH 05/10] Add code to check if xformers is available (#145) Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index e5decea3..fee6adae 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -29,6 +29,7 @@ # limitations under the License. # ========================================================================= +import importlib.util import math from typing import List, Optional, Sequence, Tuple @@ -39,9 +40,17 @@ from monai.networks.layers.factories import Pool from torch import nn -has_xformers = False +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops -# TODO: Make optional import work + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import # from monai.utils import optional_import # xformers, has_xformers = optional_import("xformers.ops", name="xformers") From 842cdf7226bf5336e54d52eaa9f77cb9f8fe7d0a Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 23 Dec 2022 18:18:56 +0000 Subject: [PATCH 06/10] Refactor attention layers (#145) Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 150 +++++++++--------- .../networks/nets/diffusion_model_unet.py | 89 +++++++---- 2 files changed, 129 insertions(+), 110 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index ad028194..9114293a 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib.util import math from typing import Optional, Sequence, Tuple @@ -18,7 +19,14 @@ import xformers.ops from monai.networks.blocks import Convolution -has_xformers = True +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False # TODO: Make optional import work # from monai.utils import optional_import @@ -162,69 +170,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class AttnBlock(nn.Module): +class AttentionBlock(nn.Module): """ Attention block. Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: number of input channels. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. - norm_eps: epsilon for the normalisation. + norm_eps: epsilon value to use for the normalisation. """ def __init__( self, spatial_dims: int, - in_channels: int, - norm_num_groups: int, - norm_eps: float, + num_channels: int, + num_head_channels: Optional[int] = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, ) -> None: super().__init__() self.spatial_dims = spatial_dims - self.in_channels = in_channels + self.num_channels = num_channels - self.num_heads = 1 - self.scale = 1 / math.sqrt(in_channels / self.num_heads) + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.to_q = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.to_k = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.to_v = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.proj_out = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = x.shape @@ -262,31 +242,25 @@ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x - # Norm + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm x = self.norm(x) - # Project to query, key, and value vectors + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v query = 
self.to_q(x) key = self.to_k(x) value = self.to_v(x) - # Reshape to B, C, SPATIAL_DIMS - batch = query.shape[0] - channel = query.shape[1] - height = query.shape[2] - width = query.shape[3] - - # Note: in order to make Torchscript's tests work, we initialise depth = 1 - depth = 1 - if self.spatial_dims == 3: - depth = query.shape[4] - - n_spatial_elements = height * width * depth - - query = query.reshape(batch, channel, n_spatial_elements) - key = key.reshape(batch, channel, n_spatial_elements) - value = value.reshape(batch, channel, n_spatial_elements) - # Multi-Head Attention query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) @@ -300,14 +274,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.reshape_batch_dim_to_heads(x) x = x.to(query.dtype) - # Reshape to B, C, H, W [, D] if self.spatial_dims == 2: - x = x.reshape(batch, channel, height, width) + x = x.transpose(-1, -2).reshape(batch, channel, height, width) if self.spatial_dims == 3: - x = x.reshape(batch, channel, height, width, depth) - - # Proj out - x = self.proj_out(x) + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) return x + residual @@ -391,7 +361,14 @@ def __init__( ) block_in_ch = block_out_ch if attention_levels[i]: - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) if i != len(ch_mult) - 1: blocks.append(Downsample(spatial_dims, block_in_ch)) @@ -399,7 +376,14 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) # Normalise and convert to latent size @@ -490,7 +474,14 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) for i in reversed(range(len(ch_mult))): @@ -501,7 +492,14 @@ def __init__( block_in_ch = block_out_ch if attention_levels[i]: - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) if i != 0: blocks.append(Upsample(spatial_dims, block_in_ch)) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index fee6adae..0c67c68b 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -130,7 +130,7 @@ def __init__( inner_dim = num_head_channels * num_attention_heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.scale = 
math.sqrt(num_head_channels) + self.scale = 1 / math.sqrt(num_head_channels) self.num_heads = num_attention_heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -354,10 +354,11 @@ class AttentionBlock(nn.Module): Args: spatial_dims: number of spatial dimensions. - num_channels: number of channels in the input and output. + num_channels: number of input channels. num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups to use for group norm. - norm_eps: epsilon value to use for group norm. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. """ def __init__( @@ -373,21 +374,48 @@ def __init__( self.num_channels = num_channels self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - # define q,k,v as linear layers - self.query = nn.Linear(num_channels, num_channels) - self.key = nn.Linear(num_channels, num_channels) - self.value = nn.Linear(num_channels, num_channels) + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) self.proj_attn = nn.Linear(num_channels, num_channels) - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x @@ -407,29 +435,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(batch, channel, height * width * depth).transpose(1, 2) # proj to q, k, v - query_proj = self.query(x) - key_proj = self.key(x) - value_proj = self.value(x) - - # transpose - query_states = self.transpose_for_scores(query_proj) 
- key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.num_channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) - attention_probs = torch.softmax(attention_scores.float(), dim=-1) + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) - # compute attention output - x = torch.matmul(attention_probs, value_states) + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) - x = x.permute(0, 2, 1, 3).contiguous() - new_x_shape = x.size()[:-2] + (self.num_channels,) - x = x.view(new_x_shape) + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) - # compute next hidden states - x = self.proj_attn(x) + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) if self.spatial_dims == 2: x = x.transpose(-1, -2).reshape(batch, channel, height, width) From df7b60746a747bbf7d151674675613e31fda6852 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 26 Dec 2022 15:06:06 +0000 Subject: [PATCH 07/10] Fix xformers import (#145) Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 1 - generative/networks/nets/diffusion_model_unet.py | 1 - 2 files changed, 2 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 9114293a..d65406f9 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import xformers.ops from monai.networks.blocks import Convolution if importlib.util.find_spec("xformers") is not None: diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 0c67c68b..cc04efe3 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -35,7 +35,6 @@ import torch import torch.nn.functional as F -import xformers.ops from monai.networks.blocks import Convolution from monai.networks.layers.factories import Pool from torch import nn From 1e2001b81610d83f67edf7f83642ec6aee1a36c9 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 6 Jan 2023 13:49:59 +0000 Subject: [PATCH 08/10] Remove xformers from requirements.txt Signed-off-by: Walter Hugo Lopez Pinaya --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6561ae05..88b7af44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,3 @@ monai-weekly==1.1.dev2248 numpy>=1.17 torch>=1.8 tqdm -xformers==0.0.16 From 1d7a053ba65b974d562e62c1b9b7b693548b7538 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Fri, 6 Jan 2023 13:55:05 +0000 Subject: [PATCH 09/10] Add instructions to install xformers Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 3 ++- generative/networks/nets/diffusion_model_unet.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index d65406f9..eb5bc16b 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -18,6 +18,7 @@ import 
torch.nn.functional as F from monai.networks.blocks import Convolution +# To install xformers, use pip install xformers==0.0.16rc401 if importlib.util.find_spec("xformers") is not None: import xformers import xformers.ops @@ -27,7 +28,7 @@ xformers = None has_xformers = False -# TODO: Make optional import work +# TODO: Use MONAI's optional_import # from monai.utils import optional_import # xformers, has_xformers = optional_import("xformers.ops", name="xformers") diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index cc04efe3..5a3de884 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -39,6 +39,7 @@ from monai.networks.layers.factories import Pool from torch import nn +# To install xformers, use pip install xformers==0.0.16rc401 if importlib.util.find_spec("xformers") is not None: import xformers import xformers.ops From c4c56c895567d3ee8c757d3573448925ca088cea Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 14 Jan 2023 14:27:49 +0000 Subject: [PATCH 10/10] Add docstrings [#145] Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/autoencoderkl.py | 4 ++++ generative/networks/nets/diffusion_model_unet.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 51f2eb6b..73ea36ea 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -207,12 +207,16 @@ def __init__( self.proj_attn = nn.Linear(num_channels, num_channels) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ batch_size, seq_len, dim = x.shape x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" batch_size, seq_len, dim = x.shape x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index d60fd772..271f85e4 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -140,12 +140,16 @@ def __init__( self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
+ """ batch_size, seq_len, dim = x.shape x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" batch_size, seq_len, dim = x.shape x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)