From 791b28450f572b42c626c674ef59faa7a9276d72 Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Mon, 13 May 2024 09:58:09 +0100
Subject: [PATCH 1/2] Only have contiguous calls after attention blocks

Signed-off-by: Mark Graham
---
 monai/networks/nets/diffusion_model_unet.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
index 38d7f816a9..f995d20e54 100644
--- a/monai/networks/nets/diffusion_model_unet.py
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch
 
 class SpatialTransformer(nn.Module):
     """
-    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
-    use of this block as support is not guaranteed. For more information see:
-    https://github.com/Project-MONAI/MONAI/issues/7227
-
     Transformer block for image-like data.
     First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally,
     reshape to image.
@@ -396,14 +392,11 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
-        h = x.contiguous()
+        h = x
         h = self.norm1(h)
         h = self.nonlinearity(h)
 
         if self.upsample is not None:
-            if h.shape[0] >= 64:
-                x = x.contiguous()
-                h = h.contiguous()
             x = self.upsample(x)
             h = self.upsample(h)
         elif self.downsample is not None:
@@ -609,7 +602,7 @@ def forward(
 
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
@@ -726,7 +719,7 @@ def forward(
 
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=context)
+            hidden_states = attn(hidden_states, context=context).contiguous()
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
@@ -790,7 +783,7 @@ def forward(
     ) -> torch.Tensor:
         del context
         hidden_states = self.resnet_1(hidden_states, temb)
-        hidden_states = self.attention(hidden_states)
+        hidden_states = self.attention(hidden_states).contiguous()
         hidden_states = self.resnet_2(hidden_states, temb)
 
         return hidden_states
@@ -1091,7 +1084,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
@@ -1669,7 +1662,7 @@ def forward(
         down_block_res_samples = new_down_block_res_samples
 
         # 5. mid
-        h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
 
         # Additional residual conections for Controlnets
         if mid_block_additional_residual is not None:
@@ -1682,7 +1675,7 @@ def forward(
             h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
 
         # 7. output block
-        output: torch.Tensor = self.out(h.contiguous())
+        output: torch.Tensor = self.out(h)
 
         return output
 

From 207854a7700a1898ce082f6a49e29cfea368970a Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Tue, 14 May 2024 09:27:13 +0100
Subject: [PATCH 2/2] Fixes for diffusion model unet

Signed-off-by: Mark Graham
---
 monai/networks/nets/spade_diffusion_model_unet.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py
index e019d21c11..594b8068af 100644
--- a/monai/networks/nets/spade_diffusion_model_unet.py
+++ b/monai/networks/nets/spade_diffusion_model_unet.py
@@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc
         h = self.nonlinearity(h)
 
         if self.upsample is not None:
-            if h.shape[0] >= 64:
-                x = x.contiguous()
-                h = h.contiguous()
             x = self.upsample(x)
             h = self.upsample(h)
         elif self.downsample is not None:
@@ -430,7 +427,7 @@ def forward(
             res_hidden_states_list = res_hidden_states_list[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             hidden_states = resnet(hidden_states, temb, seg)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
@@ -568,7 +565,7 @@ def forward(
             res_hidden_states_list = res_hidden_states_list[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             hidden_states = resnet(hidden_states, temb, seg)
-            hidden_states = attn(hidden_states, context=context)
+            hidden_states = attn(hidden_states, context=context).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
@@ -919,7 +916,7 @@ def forward(
         down_block_res_samples = new_down_block_res_samples
 
         # 5. mid
-        h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
 
         # Additional residual conections for Controlnets
        if mid_block_additional_residual is not None:
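
Why a single contiguous() call right after each attention block is enough: below is a
minimal standalone sketch, not MONAI code. The toy shapes and the Linear stand-in for
the attention layers are assumptions, modelled on the flatten/unflatten round trip that
spatial transformer blocks typically perform on image features.

import torch
import torch.nn as nn

# Toy stand-in for a (B, C, H, W) feature map entering a spatial attention block.
b, c, h, w = 2, 8, 4, 4
x = torch.randn(b, c, h, w)

# Flatten to (B, H*W, C) and run a stand-in for the attention layers (a single
# Linear here); its output is a fresh, contiguous (B, H*W, C) tensor.
seq = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
seq = nn.Linear(c, c)(seq)

# Restore image layout: the trailing permute returns a strided view, so the
# block's output is not packed in memory.
out = seq.reshape(b, h, w, c).permute(0, 3, 1, 2)
assert not out.is_contiguous()

# Materialise one packed copy at the attention output, as the patches do with
# attn(hidden_states).contiguous(), so downstream norm/conv/upsample layers
# never see a strided tensor.
out = out.contiguous()
assert out.is_contiguous()

Paying for one packed copy at each attention output is what lets the patches drop the
scattered defensive calls elsewhere: the h = x.contiguous() at the start of the resnet
forward, the batch-size-conditional h.shape[0] >= 64 guards before upsampling, and the
h.contiguous() calls before the mid and output blocks.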