This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
190 changes: 143 additions & 47 deletions generative/inferers/inferer.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions generative/networks/nets/__init__.py
@@ -15,5 +15,8 @@
from .controlnet import ControlNet
from .diffusion_model_unet import DiffusionModelUNet
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .spade_network import SPADENet
from .transformer import DecoderOnlyTransformer
from .vqvae import VQVAE
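For reference, a minimal sketch of what this export change enables: the three SPADE networks added above become importable from the nets subpackage alongside the existing classes. This only exercises the import path shown in the hunk; constructor arguments are not covered here, and the sketch assumes the repo is installed as the `generative` package.

```python
# Minimal sketch: the SPADE networks exported above can now be imported
# directly from the nets subpackage (assumes this repo is installed as
# the `generative` package).
from generative.networks.nets import (
    SPADEAutoencoderKL,
    SPADEDiffusionModelUNet,
    SPADENet,
)
```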
38 changes: 18 additions & 20 deletions generative/networks/nets/diffusion_model_unet.py
Expand Up @@ -931,7 +931,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -964,7 +964,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
)

@@ -1103,7 +1103,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.attention = None
@@ -1127,7 +1127,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
self.resnet_2 = ResnetBlock(
spatial_dims=spatial_dims,
@@ -1271,7 +1271,7 @@ def __init__(
add_upsample: bool = True,
resblock_updown: bool = False,
num_head_channels: int = 1,
use_flash_attention: bool = False
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1388,7 +1388,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1422,7 +1422,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
)

@@ -1486,7 +1486,7 @@ def get_down_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_attn:
return AttnDownBlock(
@@ -1518,7 +1518,7 @@ def get_down_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return DownBlock(
@@ -1546,7 +1546,7 @@ def get_mid_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_conditioning:
return CrossAttnMidBlock(
@@ -1560,7 +1560,7 @@ def get_mid_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return AttnMidBlock(
@@ -1592,7 +1592,7 @@ def get_up_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_attn:
return AttnUpBlock(
@@ -1626,7 +1626,7 @@ def get_up_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return UpBlock(
@@ -1688,7 +1688,7 @@ def __init__(
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
@@ -1701,9 +1701,7 @@ def __init__(
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
)
if dropout_cattn > 1.0 or dropout_cattn < 0.0:
raise ValueError(
"Dropout cannot be negative or >1.0!"
)
raise ValueError("Dropout cannot be negative or >1.0!")

# All number of channels should be multiple of num_groups
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1793,7 +1791,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

self.down_blocks.append(down_block)
@@ -1811,7 +1809,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

# up
@@ -1846,7 +1844,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

self.up_blocks.append(up_block)
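Taken together, the changes in this file thread `dropout_cattn` through the down/mid/up block factories and validate it in `DiffusionModelUNet.__init__`. A hedged sketch of how the option could be used when instantiating a conditioned model follows. Only `num_channels`, `norm_num_groups`, `with_conditioning`, `cross_attention_dim`, and `dropout_cattn` are confirmed by this diff; the remaining constructor arguments (`spatial_dims`, `in_channels`, `out_channels`, `num_res_blocks`, `attention_levels`) are assumptions based on the class's usual signature.

```python
# Hedged sketch, not taken verbatim from this PR: construct a conditioned
# DiffusionModelUNet with the new cross-attention dropout. Arguments marked
# "assumed" are not visible in this diff and follow the class's usual signature.
from generative.networks.nets import DiffusionModelUNet

model = DiffusionModelUNet(
    spatial_dims=2,                        # assumed
    in_channels=1,                         # assumed
    out_channels=1,                        # assumed
    num_res_blocks=2,                      # assumed
    num_channels=(64, 128, 256),           # each entry must be a multiple of norm_num_groups
    attention_levels=(False, True, True),  # assumed
    norm_num_groups=32,
    with_conditioning=True,                # required when cross_attention_dim is set
    cross_attention_dim=64,
    dropout_cattn=0.1,                     # validated to lie in [0.0, 1.0]
)
```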