Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import os
import subprocess
import sys
import importlib
import inspect

sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
Expand Down Expand Up @@ -137,7 +139,7 @@ def generate_apidocs(*args):
"github_user": "Project-MONAI",
"github_repo": "MONAI",
"github_version": "dev",
"doc_path": "docs/",
"doc_path": "docs/source",
"conf_py_path": "/docs/",
"VERSION": version,
}
Expand All @@ -162,3 +164,47 @@ def setup(app):
# Hook to allow for automatic generation of API docs
# before doc deployment begins.
app.connect("builder-inited", generate_apidocs)


# -- Linkcode configuration --------------------------------------------------
DEFAULT_REPOSITORY = "Project-MONAI/MONAI"
repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY)

base_code_url = f"https://github.com/{repository}/blob/{version}"
MODULE_ROOT_FOLDER = "monai"


# Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py
def linkcode_resolve(domain, info):
if domain != "py":
raise ValueError(
f"expected domain to be 'py', got {domain}."
"Please adjust linkcode_resolve to either handle this domain or ignore it."
)

mod = importlib.import_module(info["module"])
if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".")
obj = getattr(mod, objname)
try:
# object is a method of a class
obj = getattr(obj, attrname)
except AttributeError:
# object is an attribute of a class
return None
else:
obj = getattr(mod, info["fullname"])

try:
file = inspect.getsourcefile(obj)
source, lineno = inspect.getsourcelines(obj)
except TypeError:
# e.g. object is a typing.Union
return None
file = os.path.relpath(file, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
if not file.startswith(MODULE_ROOT_FOLDER):
# e.g. object is a typing.NewType
return None
start, end = lineno, lineno + len(source) - 1
url = f"{base_code_url}/{file}#L{start}-L{end}"
return url
65 changes: 59 additions & 6 deletions monai/networks/nets/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from monai.networks.blocks import Convolution
Expand All @@ -57,7 +56,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)

self.blocks = nn.ModuleList([])
Expand All @@ -73,7 +73,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)
)

Expand All @@ -85,7 +86,8 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann
strides=2,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
)
)

Expand All @@ -103,11 +105,9 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, chann

def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)

for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)

embedding = self.conv_out(embedding)

Expand Down Expand Up @@ -410,3 +410,56 @@ def forward(
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample

def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
"""
Load a state dict from a ControlNet trained with
[MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

Args:
old_state_dict: state dict from the old ControlNet model.
"""

new_state_dict = self.state_dict()
# if all keys match, just load the state dict
if all(k in new_state_dict for k in old_state_dict):
print("All keys match, loading state dict.")
self.load_state_dict(old_state_dict)
return

if verbose:
# print all new_state_dict keys that are not in old_state_dict
for k in new_state_dict:
if k not in old_state_dict:
print(f"key {k} not found in old state dict")
# and vice versa
print("----------------------------------------------")
for k in old_state_dict:
if k not in new_state_dict:
print(f"key {k} not found in new state dict")

# copy over all matching keys
for k in new_state_dict:
if k in old_state_dict:
new_state_dict[k] = old_state_dict[k]

# fix the attention blocks
attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
for block in attention_blocks:
new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat(
[
old_state_dict[f"{block}.attn1.to_q.weight"],
old_state_dict[f"{block}.attn1.to_k.weight"],
old_state_dict[f"{block}.attn1.to_v.weight"],
],
dim=0,
)

# projection
new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]

new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]

self.load_state_dict(new_state_dict)
4 changes: 1 addition & 3 deletions monai/visualize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

if TYPE_CHECKING:
from matplotlib import cm
from matplotlib import pyplot as plt
else:
plt, _ = optional_import("matplotlib", name="pyplot")
cm, _ = optional_import("matplotlib", name="cm")

__all__ = ["matshow3d", "blend_images"]

Expand Down Expand Up @@ -210,7 +208,7 @@ def blend_images(
image = repeat(image, 3, axis=0)

def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor:
_cmap = cm.get_cmap(cmap)
_cmap = plt.colormaps.get_cmap(cmap)
label_np, *_ = convert_data_type(label, np.ndarray)
label_rgb_np = _cmap(label_np[0])
label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ einops
transformers>=4.36.0
mlflow>=1.28.0, <=2.11.3
clearml>=1.10.0rc0
matplotlib!=3.5.0
matplotlib>=3.6.3
tensorboardX
types-PyYAML
pyyaml
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ all =
transformers<4.22; python_version <= '3.10'
mlflow>=1.28.0, <=2.11.3
clearml>=1.10.0rc0
matplotlib
matplotlib>=3.6.3
tensorboardX
pyyaml
fire
Expand Down Expand Up @@ -127,7 +127,7 @@ transformers =
mlflow =
mlflow>=1.28.0, <=2.11.3
matplotlib =
matplotlib
matplotlib>=3.6.3
clearml =
clearml
tensorboardX =
Expand Down
33 changes: 33 additions & 0 deletions tests/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.networks import eval_mode
from monai.networks.nets.controlnet import ControlNet
from monai.utils import optional_import
from tests.utils import skip_if_downloading_fails, testing_data_config

_, has_einops = optional_import("einops")
UNCOND_CASES_2D = [
Expand Down Expand Up @@ -177,6 +181,35 @@ def test_shape_conditioned_models(self, input_param, expected_output_shape):
self.assertEqual(len(result[0]), 2 * len(input_param["channels"]))
self.assertEqual(result[1].shape, expected_output_shape)

@skipUnless(has_einops, "Requires einops")
def test_compatibility_with_monai_generative(self):
# test loading weights from a model saved in MONAI Generative, version 0.2.3
with skip_if_downloading_fails():
net = ControlNet(
spatial_dims=2,
in_channels=1,
num_res_blocks=1,
channels=(8, 8, 8),
attention_levels=(False, False, True),
norm_num_groups=8,
with_conditioning=True,
transformer_num_layers=1,
cross_attention_dim=3,
resblock_updown=True,
)

tmpdir = tempfile.mkdtemp()
key = "controlnet_monai_generative_weights"
url = testing_data_config("models", key, "url")
hash_type = testing_data_config("models", key, "hash_type")
hash_val = testing_data_config("models", key, "hash_val")
filename = "controlnet_monai_generative_weights.pt"

weight_path = os.path.join(tmpdir, filename)
download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)

net.load_old_state_dict(torch.load(weight_path), verbose=False)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/test_matshow3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_3d_rgb(self):
every_n=2,
frame_dim=-1,
channel_dim=0,
fill_value=0,
show=False,
)

Expand Down
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth",
"hash_type": "sha256",
"hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184"
},
"controlnet_monai_generative_weights": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth",
"hash_type": "sha256",
"hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e"
}
},
"configs": {
Expand Down