Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6a3efe0
initial commit anomaly detection with gradient guidance
JuliaWolleb Jan 16, 2023
a750e53
first reversed loop for DDIM is implemented. Classifier guidance is o…
JuliaWolleb Feb 9, 2023
cc7b704
anomaly detection tutorial is complete, the training needs to be checked
JuliaWolleb Feb 22, 2023
2a0bed9
cleaned up the classification network for gradient guidance
JuliaWolleb Feb 22, 2023
20535f0
cleaning up
JuliaWolleb Feb 23, 2023
04dff72
pull main branch
JuliaWolleb Feb 23, 2023
c0cf8b3
run tests on all files
JuliaWolleb Feb 23, 2023
7bb6d88
create folder for the anomaly detection tutorials
JuliaWolleb Mar 15, 2023
72922f0
remove old folder
JuliaWolleb Mar 15, 2023
b51edeb
pull main branch
JuliaWolleb Mar 15, 2023
0d62a18
autofix after testing
JuliaWolleb Mar 15, 2023
5b2cf1b
move anomaly detection tutorials to the folder /tutorials/generative/…
JuliaWolleb Mar 15, 2023
df74d7d
remove old folder
JuliaWolleb Mar 15, 2023
18ebcb9
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
e9b3978
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
88672f6
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
bc204f3
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
b7c5613
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
614fb9e
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 17, 2023
1cc8fc2
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 20, 2023
55513e1
Update generative/networks/nets/diffusion_model_unet.py
JuliaWolleb Mar 20, 2023
b4b82c3
include Walters changes in the tutorial
JuliaWolleb Mar 21, 2023
28613ac
Merge branch '155-add-tutorial-about-diffusion-models-for-medical-ano…
JuliaWolleb Mar 21, 2023
574444d
pull main branch
JuliaWolleb Mar 21, 2023
62a8fe3
Fix changed files
Warvito Mar 22, 2023
e09220e
Fix changed files
Warvito Mar 22, 2023
4b786f8
Move files and update License
Warvito Mar 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions generative/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,3 +1890,176 @@ def forward(
h = self.out(h)

return h


class DiffusionModelEncoder(nn.Module):
"""
Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on
Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).

Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
num_res_blocks: number of residual blocks (see ResnetBlock) per level.
num_channels: tuple of block output channels.
attention_levels: list of levels to add attention.
norm_num_groups: number of groups for the normalization.
norm_eps: epsilon for the normalization.
resblock_updown: if True use residual blocks for downsampling.
num_head_channels: number of channels in each attention head.
with_conditioning: if True add spatial transformers to perform conditioning.
transformer_num_layers: number of layers of Transformer blocks to use.
cross_attention_dim: number of context dimensions to use.
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
upcast_attention: if True, upcast attention operations to full precision.
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
num_channels: Sequence[int] = (32, 64, 64, 64),
attention_levels: Sequence[bool] = (False, False, True, True),
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
resblock_updown: bool = False,
num_head_channels: int | Sequence[int] = 8,
with_conditioning: bool = False,
transformer_num_layers: int = 1,
cross_attention_dim: int | None = None,
num_class_embeds: int | None = None,
upcast_attention: bool = False,
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
raise ValueError(
"DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
"when using with_conditioning."
)
if cross_attention_dim is not None and with_conditioning is False:
raise ValueError(
"DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
)

# All number of channels should be multiple of num_groups
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
if len(num_channels) != len(attention_levels):
raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")

if isinstance(num_head_channels, int):
num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

if len(num_head_channels) != len(attention_levels):
raise ValueError(
"num_head_channels should have the same length as attention_levels. For the i levels without attention,"
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
)

self.in_channels = in_channels
self.block_out_channels = num_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_levels = attention_levels
self.num_head_channels = num_head_channels
self.with_conditioning = with_conditioning

# input
self.conv_in = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=num_channels[0],
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)

# time
time_embed_dim = num_channels[0] * 4
self.time_embed = nn.Sequential(
nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
)

# class embedding
self.num_class_embeds = num_class_embeds
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

# down
self.down_blocks = nn.ModuleList([])
output_channel = num_channels[0]
for i in range(len(num_channels)):
input_channel = output_channel
output_channel = num_channels[i]
is_final_block = i == len(num_channels) # - 1

down_block = get_down_block(
spatial_dims=spatial_dims,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
num_res_blocks=num_res_blocks[i],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
add_downsample=not is_final_block,
resblock_updown=resblock_updown,
with_attn=(attention_levels[i] and not with_conditioning),
with_cross_attn=(attention_levels[i] and with_conditioning),
num_head_channels=num_head_channels[i],
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)

self.down_blocks.append(down_block)

self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))

def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
x: input tensor (N, C, SpatialDims).
timesteps: timestep tensor (N,).
context: context tensor (N, 1, ContextDim).
class_labels: context tensor (N, ).
"""
# 1. time
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=x.dtype)
emb = self.time_embed(t_emb)

# 2. class
if self.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels)
class_emb = class_emb.to(dtype=x.dtype)
emb = emb + class_emb

# 3. initial convolution
h = self.conv_in(x)

# 4. down
if context is not None and self.with_conditioning is False:
raise ValueError("model should have with_conditioning = True if context is provided")
for downsample_block in self.down_blocks:
h, _ = downsample_block(hidden_states=h, temb=emb, context=context)

h = h.reshape(h.shape[0], -1)
output = self.out(h)

return output
Loading