Skip to content
7 changes: 3 additions & 4 deletions monai/networks/blocks/spade_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks import ADN, Convolution
from monai.networks.blocks import Convolution
from monai.networks.layers.utils import get_norm_layer


class SPADE(nn.Module):
Expand Down Expand Up @@ -50,9 +51,7 @@ def __init__(
norm_params = {}
if len(norm_params) != 0:
norm = (norm, norm_params)
self.param_free_norm = ADN(
act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc
)
self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)
self.mlp_shared = Convolution(
spatial_dims=spatial_dims,
in_channels=label_nc,
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
)
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .spade_network import SPADENet
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
Expand Down
8 changes: 4 additions & 4 deletions monai/networks/nets/spade_diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
__all__ = ["SPADEDiffusionModelUNet"]


class SPADEResnetBlock(nn.Module):
class SPADEDiffResBlock(nn.Module):
"""
Residual block with timestep conditioning and SPADE norm.
Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -353,7 +353,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -488,7 +488,7 @@ def __init__(
resnet_in_channels = prev_output_channel if i == 0 else out_channels

resnets.append(
SPADEResnetBlock(
SPADEDiffResBlock(
spatial_dims=spatial_dims,
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
Expand Down
Loading