From a8c2a1adc947248c377b69721d2fac7cc0840e6c Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:06:00 +0000 Subject: [PATCH 01/67] 5394 enable py311 tests (#7182) Fixes https://github.com/Project-MONAI/MONAI/issues/5394 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li Signed-off-by: Mark Graham --- tests/utils.py | 54 ++++++++++++++++++++------------------------------ 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 4bda8955ed..cf1711292f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -724,22 +724,16 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): """ # TODO: would be nice to use GPU if available, but it currently causes CI failures. device = "cpu" - try: - with tempfile.TemporaryDirectory() as tempdir: - convert_to_torchscript( - model=net, - filename_or_obj=os.path.join(tempdir, "model.ts"), - verify=True, - inputs=inputs, - device=device, - rtol=rtol, - atol=atol, - ) - except (RuntimeError, AttributeError): - if sys.version_info.major == 3 and sys.version_info.minor == 11: - warnings.warn("skipping py 3.11") - return - raise + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + atol=atol, + ) def test_onnx_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): @@ -753,23 +747,17 @@ def test_onnx_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): # TODO: would be nice to use GPU if available, but it currently causes CI failures. device = "cpu" _, has_onnxruntime = optional_import("onnxruntime") - try: - with tempfile.TemporaryDirectory() as tempdir: - convert_to_onnx( - model=net, - filename=os.path.join(tempdir, "model.onnx"), - verify=True, - inputs=inputs, - device=device, - use_ort=has_onnxruntime, - rtol=rtol, - atol=atol, - ) - except (RuntimeError, AttributeError): - if sys.version_info.major == 3 and sys.version_info.minor == 11: - warnings.warn("skipping py 3.11") - return - raise + with tempfile.TemporaryDirectory() as tempdir: + convert_to_onnx( + model=net, + filename=os.path.join(tempdir, "model.onnx"), + verify=True, + inputs=inputs, + device=device, + use_ort=has_onnxruntime, + rtol=rtol, + atol=atol, + ) def download_url_or_skip_test(*args, **kwargs): From 40aaea3358462ea315e9a5056b6ef77d86e67eba Mon Sep 17 00:00:00 2001 From: Dong Yang Date: Thu, 2 Nov 2023 03:55:38 -0600 Subject: [PATCH 02/67] [Auto3DSeg] Add mlflow support in autorunner. (#7176) Add MLflow support in AutoRunner Class. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: dongy Signed-off-by: Mark Graham --- monai/apps/auto3dseg/auto_runner.py | 5 +++++ monai/apps/auto3dseg/bundle_gen.py | 34 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 23fb3eb250..80ae34180e 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -83,6 +83,8 @@ class AutoRunner: zip url will be downloaded and extracted into the work_dir. allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer. + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote + tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None. kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage. @@ -209,6 +211,7 @@ def __init__( not_use_cache: bool = False, templates_path_or_url: str | None = None, allow_skip: bool = True, + mlflow_tracking_uri: str | None = None, **kwargs: Any, ): logger.info(f"AutoRunner using work directory {work_dir}") @@ -220,6 +223,7 @@ def __init__( self.algos = algos self.templates_path_or_url = templates_path_or_url self.allow_skip = allow_skip + self.mlflow_tracking_uri = mlflow_tracking_uri self.kwargs = deepcopy(kwargs) if input is None and os.path.isfile(self.data_src_cfg_name): @@ -783,6 +787,7 @@ def run(self): templates_path_or_url=self.templates_path_or_url, data_stats_filename=self.datastats_filename, data_src_cfg_name=self.data_src_cfg_name, + mlflow_tracking_uri=self.mlflow_tracking_uri, ) if self.gpu_customization: diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py index a091739dd3..03b9c8bbf4 100644 --- a/monai/apps/auto3dseg/bundle_gen.py +++ b/monai/apps/auto3dseg/bundle_gen.py @@ -85,6 +85,7 @@ def __init__(self, template_path: PathLike): self.template_path = template_path self.data_stats_files = "" self.data_list_file = "" + self.mlflow_tracking_uri = None self.output_path = "" self.name = "" self.best_metric = None @@ -129,6 +130,17 @@ def set_data_source(self, data_src_cfg: str) -> None: """ self.data_list_file = data_src_cfg + def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None: + """ + Set the tracking URI for MLflow server + + Args: + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. + """ + self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore + def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict: """ The configuration files defined when constructing this Algo instance might not have a complete training @@ -432,6 +444,9 @@ class BundleGen(AlgoGen): data_stats_filename: the path to the data stats file (generated by DataAnalyzer). data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}. + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. .. code-block:: bash python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml" @@ -444,6 +459,7 @@ def __init__( templates_path_or_url: str | None = None, data_stats_filename: str | None = None, data_src_cfg_name: str | None = None, + mlflow_tracking_uri: str | None = None, ): if algos is None or isinstance(algos, (list, tuple, str)): if templates_path_or_url is None: @@ -496,6 +512,7 @@ def __init__( self.data_stats_filename = data_stats_filename self.data_src_cfg_name = data_src_cfg_name + self.mlflow_tracking_uri = mlflow_tracking_uri self.history: list[dict] = [] def set_data_stats(self, data_stats_filename: str) -> None: @@ -524,6 +541,21 @@ def get_data_src(self): """Get the data source filename""" return self.data_src_cfg_name + def set_mlflow_tracking_uri(self, mlflow_tracking_uri): + """ + Set the tracking URI for MLflow server + + Args: + mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of + the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if + the value is None. + """ + self.mlflow_tracking_uri = mlflow_tracking_uri + + def get_mlflow_tracking_uri(self): + """Get the tracking URI for MLflow server""" + return self.mlflow_tracking_uri + def get_history(self) -> list: """Get the history of the bundleAlgo object with their names/identifiers""" return self.history @@ -575,9 +607,11 @@ def generate( for f_id in ensure_tuple(fold_idx): data_stats = self.get_data_stats() data_src_cfg = self.get_data_src() + mlflow_tracking_uri = self.get_mlflow_tracking_uri() gen_algo = deepcopy(algo) gen_algo.set_data_stats(data_stats) gen_algo.set_data_source(data_src_cfg) + gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri) name = f"{gen_algo.name}_{f_id}" if allow_skip: From 2971fcd08776630829466d82d7e3780ea695a85d Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:08:22 +0000 Subject: [PATCH 03/67] 6627 reading dcm folder filtering filenames (#7181) Fixes #6627 ### Description adding a `fname_regex` option to the pydicom reader ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li Signed-off-by: Mark Graham --- monai/apps/datasets.py | 9 ++++++++- monai/apps/tcia/__init__.py | 9 ++++++++- monai/apps/tcia/utils.py | 12 ++++++++++-- monai/data/image_reader.py | 22 ++++++++++++++++++---- tests/test_load_image.py | 5 +++++ tests/test_tciadataset.py | 25 ++++++++++++++++++++++--- 6 files changed, 71 insertions(+), 11 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 54829cb946..bb10eb6b11 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -22,6 +22,7 @@ import numpy as np from monai.apps.tcia import ( + DCM_FILENAME_REGEX, download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, @@ -442,6 +443,10 @@ class TciaDataset(Randomizable, CacheDataset): specific_tags: tags that will be loaded for "SEG" series. This argument will be used in `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010), (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)]. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. + Default to `"^(?!.*LICENSE).*"`, ignoring any file name containing `"LICENSE"`. val_frac: percentage of validation fraction in the whole dataset, default is 0.2. seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0. note to set same seed for `training` and `validation` sections. @@ -509,6 +514,7 @@ def __init__( (0x0020, 0x0011), # Series Number (0x0020, 0x0012), # Acquisition Number ), + fname_regex: str = DCM_FILENAME_REGEX, seed: int = 0, val_frac: float = 0.2, cache_num: int = sys.maxsize, @@ -548,12 +554,13 @@ def __init__( if not os.path.exists(download_dir): raise RuntimeError(f"Cannot find dataset directory: {download_dir}.") + self.fname_regex = fname_regex self.indices: np.ndarray = np.array([]) self.datalist = self._generate_data_list(download_dir) if transform == (): - transform = LoadImaged(reader="PydicomReader", keys=["image"]) + transform = LoadImaged(keys=["image"], reader="PydicomReader", fname_regex=self.fname_regex) CacheDataset.__init__( self, data=self.datalist, diff --git a/monai/apps/tcia/__init__.py b/monai/apps/tcia/__init__.py index af3d44fd14..e33d4abfbc 100644 --- a/monai/apps/tcia/__init__.py +++ b/monai/apps/tcia/__init__.py @@ -12,4 +12,11 @@ from __future__ import annotations from .label_desc import TCIA_LABEL_DICT -from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study +from .utils import ( + BASE_URL, + DCM_FILENAME_REGEX, + download_tcia_series_instance, + get_tcia_metadata, + get_tcia_ref_uid, + match_tcia_ref_uid_in_study, +) diff --git a/monai/apps/tcia/utils.py b/monai/apps/tcia/utils.py index 9c120f0072..5524b488e9 100644 --- a/monai/apps/tcia/utils.py +++ b/monai/apps/tcia/utils.py @@ -21,10 +21,18 @@ requests_get, has_requests = optional_import("requests", name="get") pd, has_pandas = optional_import("pandas") -__all__ = ["get_tcia_metadata", "download_tcia_series_instance", "get_tcia_ref_uid", "match_tcia_ref_uid_in_study"] - +DCM_FILENAME_REGEX = r"^(?!.*LICENSE).*" # excluding the file with "LICENSE" in its name BASE_URL = "https://services.cancerimagingarchive.net/nbia-api/services/v1/" +__all__ = [ + "get_tcia_metadata", + "download_tcia_series_instance", + "get_tcia_ref_uid", + "match_tcia_ref_uid_in_study", + "DCM_FILENAME_REGEX", + "BASE_URL", +] + def get_tcia_metadata(query: str, attribute: str | None = None) -> list: """ diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 4c7f2c8c3b..fe199d9570 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -13,6 +13,7 @@ import glob import os +import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -403,8 +404,12 @@ class PydicomReader(ImageReader): label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. Keys of the dict are the classes, and values are the corresponding class number. For example: for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. + Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. kwargs: additional args for `pydicom.dcmread` API. more details about available args: - https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread + https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html If the `get_data` function will be called (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, @@ -418,6 +423,7 @@ def __init__( swap_ij: bool = True, prune_metadata: bool = True, label_dict: dict | None = None, + fname_regex: str = "", **kwargs, ): super().__init__() @@ -427,6 +433,7 @@ def __init__( self.swap_ij = swap_ij self.prune_metadata = prune_metadata self.label_dict = label_dict + self.fname_regex = fname_regex def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -467,9 +474,16 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): name = f"{name}" if Path(name).is_dir(): # read DICOM series - series_slcs = glob.glob(os.path.join(name, "*")) - series_slcs = [slc for slc in series_slcs if "LICENSE" not in slc] - slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs] + if self.fname_regex is not None: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] + else: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] + slices = [] + for slc in series_slcs: + try: + slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + except pydicom.errors.InvalidDicomError as e: + warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 6f29e7ac50..b6a10bceb4 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -226,6 +226,11 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ) self.assertTupleEqual(result.shape, expected_np_shape) + def test_no_files(self): + with self.assertRaisesRegex(RuntimeError, "list index out of range"): # fname_regex excludes everything + LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM") + LoadImage(image_only=True, reader="PydicomReader", fname_regex=None)("tests/testing_data/CT_DICOM") + def test_itk_dicom_series_reader_single(self): result = LoadImage(image_only=True, reader="ITKReader")(self.data_dir) self.assertEqual(result.meta["filename_or_obj"], f"{Path(self.data_dir)}") diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py index 7a14262587..2a3928f9aa 100644 --- a/tests/test_tciadataset.py +++ b/tests/test_tciadataset.py @@ -16,7 +16,7 @@ import unittest from monai.apps import TciaDataset -from monai.apps.tcia import TCIA_LABEL_DICT +from monai.apps.tcia import DCM_FILENAME_REGEX, TCIA_LABEL_DICT from monai.data import MetaTensor from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd from tests.utils import skip_if_downloading_fails, skip_if_quick @@ -32,7 +32,12 @@ def test_values(self): transform = Compose( [ - LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=TCIA_LABEL_DICT[collection]), + LoadImaged( + keys=["image", "seg"], + reader="PydicomReader", + fname_regex=DCM_FILENAME_REGEX, + label_dict=TCIA_LABEL_DICT[collection], + ), EnsureChannelFirstd(keys="image", channel_dim="no_channel"), ScaleIntensityd(keys="image"), ] @@ -82,10 +87,24 @@ def _test_dataset(dataset): self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24)) self.assertEqual(len(data), int(download_len * val_frac)) data = TciaDataset( - root_dir=testing_dir, collection=collection, section="validation", download=False, val_frac=val_frac + root_dir=testing_dir, + collection=collection, + section="validation", + download=False, + fname_regex=DCM_FILENAME_REGEX, + val_frac=val_frac, ) self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24)) self.assertEqual(len(data), download_len) + with self.assertWarns(UserWarning): + data = TciaDataset( + root_dir=testing_dir, + collection=collection, + section="validation", + fname_regex=".*", # all files including 'LICENSE' is not a valid input + download=False, + val_frac=val_frac, + )[0] shutil.rmtree(os.path.join(testing_dir, collection)) try: From 3fc19937712078d984d7c797e7119ada954c3a6a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:25:03 +0800 Subject: [PATCH 04/67] Precision issue in `get_confusion_matrix ` (#7187) Fixes #7186 ### Description remove unnecessary float() ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/metrics/confusion_matrix.py | 13 ++++------- monai/metrics/f_beta_score.py | 13 ++++------- monai/metrics/meaniou.py | 3 --- monai/metrics/regression.py | 12 ---------- monai/metrics/surface_dice.py | 3 --- monai/metrics/utils.py | 32 +++++++++++++------------- tests/test_compute_confusion_matrix.py | 15 ++++++++++++ 7 files changed, 41 insertions(+), 50 deletions(-) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 9083b7733f..26ec823081 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -153,9 +153,6 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -165,12 +162,12 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).float() - tn = ((y_pred + y) == 0).float() + tp = (y_pred + y) == 2 + tn = (y_pred + y) == 0 - tp = tp.sum(dim=[2]) - tn = tn.sum(dim=[2]) - p = y.sum(dim=[2]) + tp = tp.sum(dim=[2]).float() + tn = tn.sum(dim=[2]).float() + p = y.sum(dim=[2]).float() n = y.shape[-1] - p fn = p - tp diff --git a/monai/metrics/f_beta_score.py b/monai/metrics/f_beta_score.py index a5703105a2..61e4525662 100644 --- a/monai/metrics/f_beta_score.py +++ b/monai/metrics/f_beta_score.py @@ -63,9 +63,6 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") @@ -75,12 +72,12 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: # As for classification tasks, S equals to 1. y_pred = y_pred.view(batch_size, n_class, -1) y = y.view(batch_size, n_class, -1) - tp = ((y_pred + y) == 2).float() - tn = ((y_pred + y) == 0).float() + tp = (y_pred + y) == 2 + tn = (y_pred + y) == 0 - tp = tp.sum(dim=[2]) - tn = tn.sum(dim=[2]) - p = y.sum(dim=[2]) + tp = tp.sum(dim=[2]).float() + tn = tn.sum(dim=[2]).float() + p = y.sum(dim=[2]).float() n = y.shape[-1] - p fn = p - tp diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 7d1ae49f25..65c53f7aa5 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -130,9 +130,6 @@ def compute_iou( if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) - y = y.float() - y_pred = y_pred.float() - if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index f37230f09e..9d29654ee3 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -111,9 +111,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - return compute_mean_error_metrics(y_pred, y, func=self.sq_func) @@ -143,9 +140,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.abs_func = torch.abs def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - return compute_mean_error_metrics(y_pred, y, func=self.abs_func) @@ -176,9 +170,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_ self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y_pred = y_pred.float() - y = y.float() - mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) return torch.sqrt(mse_out) @@ -218,9 +209,6 @@ def __init__( self.sq_func = partial(torch.pow, exponent=2.0) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any: - y_pred = y_pred.float() - y = y.float() - mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index f8c402a756..635eb1bc24 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -228,9 +228,6 @@ def compute_surface_dice( f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)." ) - y = y.float() - y_pred = y_pred.float() - batch_size, n_class = y_pred.shape[:2] if n_class != len(class_thresholds): diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index fe145b0f50..c139fc35ed 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -95,37 +95,37 @@ def do_metric_reduction( # some elements might be Nan (if ground truth y was missing (zeros)) # we need to account for it nans = torch.isnan(f) - not_nans = (~nans).float() + not_nans = ~nans - t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) + t_zero = torch.zeros(1, device=f.device, dtype=torch.float) reduction = look_up_option(reduction, MetricReduction) if reduction == MetricReduction.NONE: - return f, not_nans + return f, not_nans.float() f[nans] = 0 if reduction == MetricReduction.MEAN: # 2 steps, first, mean by channel (accounting for nans), then by batch - not_nans = not_nans.sum(dim=1) - f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average + not_nans = not_nans.sum(dim=1).float() + f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average - not_nans = (not_nans > 0).float().sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + not_nans = (not_nans > 0).sum(dim=0).float() + f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM: - not_nans = not_nans.sum(dim=[0, 1]) + not_nans = not_nans.sum(dim=[0, 1]).float() f = torch.sum(f, dim=[0, 1]) # sum over the batch and channel dims elif reduction == MetricReduction.MEAN_BATCH: - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + not_nans = not_nans.sum(dim=0).float() + f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average elif reduction == MetricReduction.SUM_BATCH: - not_nans = not_nans.sum(dim=0) - f = f.sum(dim=0) # the batch sum + not_nans = not_nans.sum(dim=0).float() + f = f.sum(dim=0).float() # the batch sum elif reduction == MetricReduction.MEAN_CHANNEL: - not_nans = not_nans.sum(dim=1) - f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average + not_nans = not_nans.sum(dim=1).float() + f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average elif reduction == MetricReduction.SUM_CHANNEL: - not_nans = not_nans.sum(dim=1) - f = f.sum(dim=1) # the channel sum + not_nans = not_nans.sum(dim=1).float() + f = f.sum(dim=1).float() # the channel sum elif reduction != MetricReduction.NONE: raise ValueError( f"Unsupported reduction: {reduction}, available options are " diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index a886d8b7e4..e0a92aec67 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -210,6 +210,14 @@ TEST_CASES_CLF = [data_clf.copy(), result_clf] +TEST_CASE_PRECISION = [ + { + "y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device), + "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device), + }, + torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]), +] + class TestConfusionMatrix(unittest.TestCase): @parameterized.expand([TEST_CASE_CONFUSION_MATRIX]) @@ -274,6 +282,13 @@ def test_clf_with_nan(self, input_data, expected_value): expected_value = compute_confusion_matrix_metric("tpr", expected_value) assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) + @parameterized.expand([TEST_CASE_PRECISION]) + def test_precision(self, input_data, expected_value): + # include or ignore background + result = get_confusion_matrix(**input_data) + assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) + np.testing.assert_equal(result.device, input_data["y_pred"].device) + if __name__ == "__main__": unittest.main() From 3572fd5ca269fbf7297dec7fbed73c2a3d292915 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 3 Nov 2023 02:54:04 -0700 Subject: [PATCH 05/67] autorunner params from config (#7175) allows setting AutoRunner params from config allows specifying number of folds in config --------- Signed-off-by: myron Signed-off-by: Mark Graham --- monai/apps/auto3dseg/auto_runner.py | 79 +++++++++++++++++++---------- tests/test_vis_gradcam.py | 3 +- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 80ae34180e..e4c2d908b7 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -214,22 +214,11 @@ def __init__( mlflow_tracking_uri: str | None = None, **kwargs: Any, ): - logger.info(f"AutoRunner using work directory {work_dir}") - os.makedirs(work_dir, exist_ok=True) - - self.work_dir = os.path.abspath(work_dir) - self.data_src_cfg = dict() - self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml") - self.algos = algos - self.templates_path_or_url = templates_path_or_url - self.allow_skip = allow_skip - self.mlflow_tracking_uri = mlflow_tracking_uri - self.kwargs = deepcopy(kwargs) - - if input is None and os.path.isfile(self.data_src_cfg_name): - input = self.data_src_cfg_name + if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")): + input = os.path.join(os.path.abspath(work_dir), "input.yaml") logger.info(f"Input config is not provided, using the default {input}") + self.data_src_cfg = dict() if isinstance(input, dict): self.data_src_cfg = input elif isinstance(input, str) and os.path.isfile(input): @@ -238,6 +227,51 @@ def __init__( else: raise ValueError(f"{input} is not a valid file or dict") + if "work_dir" in self.data_src_cfg: # override from config + work_dir = self.data_src_cfg["work_dir"] + self.work_dir = os.path.abspath(work_dir) + + logger.info(f"AutoRunner using work directory {self.work_dir}") + os.makedirs(self.work_dir, exist_ok=True) + self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml") + + self.algos = algos + self.templates_path_or_url = templates_path_or_url + self.allow_skip = allow_skip + + # cache.yaml + self.not_use_cache = not_use_cache + self.cache_filename = os.path.join(self.work_dir, "cache.yaml") + self.cache = self.read_cache() + self.export_cache() + + # determine if we need to analyze, algo_gen or train from cache, unless manually provided + self.analyze = not self.cache["analyze"] if analyze is None else analyze + self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen + self.train = train + self.ensemble = ensemble # last step, no need to check + self.hpo = hpo and has_nni + self.hpo_backend = hpo_backend + self.mlflow_tracking_uri = mlflow_tracking_uri + self.kwargs = deepcopy(kwargs) + + # parse input config for AutoRunner param overrides + for param in [ + "analyze", + "algo_gen", + "train", + "hpo", + "ensemble", + "not_use_cache", + "allow_skip", + ]: # override from config + if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool): + setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"] + + for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config + if param in self.data_src_cfg: + setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"] + missing_keys = {"dataroot", "datalist", "modality"}.difference(self.data_src_cfg.keys()) if len(missing_keys) > 0: raise ValueError(f"Config keys are missing {missing_keys}") @@ -256,6 +290,8 @@ def __init__( # inspect and update folds num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename) + if "num_fold" in self.data_src_cfg: + num_fold = int(self.data_src_cfg["num_fold"]) # override from config self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input ConfigParser.export_config_file( @@ -266,17 +302,6 @@ def __init__( self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml") self.datalist_filename = datalist_filename - self.not_use_cache = not_use_cache - self.cache_filename = os.path.join(self.work_dir, "cache.yaml") - self.cache = self.read_cache() - self.export_cache() - - # determine if we need to analyze, algo_gen or train from cache, unless manually provided - self.analyze = not self.cache["analyze"] if analyze is None else analyze - self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen - self.train = train - self.ensemble = ensemble # last step, no need to check - self.set_training_params() self.set_device_info() self.set_prediction_params() @@ -288,9 +313,9 @@ def __init__( self.gpu_customization_specs: dict[str, Any] = {} # hpo - if hpo_backend.lower() != "nni": + if self.hpo_backend.lower() != "nni": raise NotImplementedError("HPOGen backend only supports NNI") - self.hpo = hpo and has_nni + self.hpo = self.hpo and has_nni self.set_hpo_params() self.search_space: dict[str, dict[str, Any]] = {} self.hpo_tasks = 0 diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index f5ba188082..4b554de0aa 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -20,7 +20,7 @@ from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import GradCAM, GradCAMpp -from tests.utils import assert_allclose +from tests.utils import assert_allclose, skip_if_quick class DenseNetAdjoint(DenseNet121): @@ -147,6 +147,7 @@ def __call__(self, x, adjoint_info): TESTS_ILL.append([cam]) +@skip_if_quick class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, cam_class, input_data, expected_shape): From d8b6676c8eb92b3b4f82d5fc55c1ad75399d39c3 Mon Sep 17 00:00:00 2001 From: Kaibo Tang <99367900+kvttt@users.noreply.github.com> Date: Fri, 3 Nov 2023 07:25:54 -0400 Subject: [PATCH 06/67] Implemented VoxelMorph (#7178) Fixes #5484. ### Description Implemented VoxelMorph and added some docstrings. Checked coding style locally. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: kaibo Signed-off-by: Mark Graham --- docs/source/networks.rst | 12 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/voxelmorph.py | 481 ++++++++++++++++++++++++++++++ tests/test_voxelmorph.py | 281 +++++++++++++++++ 4 files changed, 775 insertions(+) create mode 100644 monai/networks/nets/voxelmorph.py create mode 100644 tests/test_voxelmorph.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 7b7888732f..8eada7933f 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -710,6 +710,18 @@ Nets .. autoclass:: Quicknat :members: +`VoxelMorph` +~~~~~~~~~~~~ +.. autoclass:: VoxelMorphUNet + :members: + +.. autoclass:: voxelmorphunet + :members: + +.. autoclass:: VoxelMorph + :members: + +.. autoclass:: voxelmorph Utilities --------- diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 1fb0f08ccc..9247aaee85 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -111,3 +111,4 @@ from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet +from .voxelmorph import VoxelMorph, VoxelMorphUNet diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py new file mode 100644 index 0000000000..0496cfc8f8 --- /dev/null +++ b/monai/networks/nets/voxelmorph.py @@ -0,0 +1,481 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.blocks.upsample import UpSample +from monai.networks.blocks.warp import DVF2DDF, Warp +from monai.networks.layers.simplelayers import SkipConnection +from monai.utils import alias, export + +__all__ = ["VoxelMorphUNet", "voxelmorphunet", "VoxelMorph", "voxelmorph"] + + +@export("monai.networks.nets") +@alias("voxelmorphunet") +class VoxelMorphUNet(nn.Module): + """ + The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details. + + A concatenated pair of images (moving and fixed) is first passed through a UNet. The output of the UNet is then + passed through a series of convolution blocks to produce the final prediction of the displacement field (DDF) or the + stationary velocity field (DVF). + + In the original implementation, downsample is achieved through maxpooling, here one has the option to use either + maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the + original implementation. Note that for upsampling, the authors of VoxelMorph used nearest neighbor interpolation + instead of transposed convolution. In this implementation, only nearest neighbor interpolation is supported in order + to be consistent with the original implementation. + + An instance of this class can be used as a backbone network for constructing a VoxelMorph network. See the + documentation of :py:class:`monai.networks.nets.VoxelMorph` for more details and an example on how to construct a + VoxelMorph network. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input volume after concatenation of moving and fixed images. + unet_out_channels: number of channels in the output of the UNet. + channels: number of channels in each layer of the UNet. See the following example for more details. + final_conv_channels: number of channels in each layer of the final convolution block. + final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU. + Since VoxelMorph was originally implemented in tensorflow where the default negative slope for + LeakyReLU was 0.2, we use the same default value here. + kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3. + up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3. + act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2. + norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None. + dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout). + bias: whether to use bias in all convolution layers in the UNet. Defaults to True. + use_maxpool: whether to use maxpooling in the downsampling path of the UNet. Defaults to True. + Using maxpooling is the consistent with the original implementation of VoxelMorph. + But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False). + adn_ordering: ordering of activation, dropout, and normalization. Defaults to "NDA". + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + unet_out_channels: int, + channels: Sequence[int], + final_conv_channels: Sequence[int], + final_conv_act: tuple | str | None = "LEAKYRELU", + kernel_size: Sequence[int] | int = 3, + up_kernel_size: Sequence[int] | int = 3, + act: tuple | str = "LEAKYRELU", + norm: tuple | str | None = None, + dropout: float = 0.0, + bias: bool = True, + use_maxpool: bool = True, + adn_ordering: str = "NDA", + ) -> None: + super().__init__() + + if spatial_dims not in (2, 3): + raise ValueError("spatial_dims must be either 2 or 3.") + if in_channels % 2 != 0: + raise ValueError("in_channels must be divisible by 2.") + if len(channels) < 2: + raise ValueError("the length of `channels` should be no less than 2.") + if len(channels) % 2 != 0: + raise ValueError("the elements of `channels` should be specified in pairs.") + if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims: + raise ValueError("the length of `kernel_size` should equal to `dimensions`.") + if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims: + raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") + + # UNet args + self.dimensions = spatial_dims + self.in_channels = in_channels + self.unet_out_channels = unet_out_channels + self.channels = channels + self.kernel_size = kernel_size + self.up_kernel_size = up_kernel_size + self.act = ( + ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) + if isinstance(act, str) and act.upper() == "LEAKYRELU" + else act + ) + self.norm = norm + self.dropout = dropout + self.bias = bias + self.use_maxpool = use_maxpool + self.adn_ordering = adn_ordering + + # final convolutions args + self.final_conv_channels = final_conv_channels + self.final_conv_act = ( + ("leakyrelu", {"negative_slope": 0.2, "inplace": True}) + if isinstance(final_conv_act, str) and final_conv_act.upper() == "LEAKYRELU" + else final_conv_act + ) + + def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> nn.Module: + """ + Builds the UNet structure recursively. + + Args: + inc: number of input channels. + outc: number of output channels. + channels: sequence of channels for each pair of down and up layers. + is_top: True if this is the top block. + """ + + next_c_in, next_c_out = channels[0:2] + upc = next_c_in + next_c_out + + subblock: nn.Module + + if len(channels) > 2: + subblock = _create_block(next_c_in, next_c_out, channels[2:], is_top=False) # continue recursion down + else: + # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer + subblock = self._get_bottom_layer(next_c_in, next_c_out) + + down = self._get_down_layer(inc, next_c_in, is_top) # create layer in downsampling path + up = self._get_up_layer(upc, outc, is_top) # create layer in upsampling path + + return self._get_connection_block(down, up, subblock) + + def _create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Module: + """ + Builds the final convolution blocks. + + Args: + inc: number of input channels, should be the same as `unet_out_channels`. + outc: number of output channels, should be the same as `spatial_dims`. + channels: sequence of channels for each convolution layer. + + Note: there is no activation after the last convolution layer as per the original implementation. + """ + + mod: nn.Module = nn.Sequential() + + for i, c in enumerate(channels): + mod.add_module( + f"final_conv_{i}", + Convolution( + self.dimensions, + inc, + c, + kernel_size=self.kernel_size, + act=self.final_conv_act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ), + ) + inc = c + + mod.add_module( + "final_conv_out", + Convolution( + self.dimensions, + inc, + outc, + kernel_size=self.kernel_size, + act=None, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ), + ) + + return mod + + self.net = nn.Sequential( + _create_block(in_channels, unet_out_channels, self.channels, is_top=True), + _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels), + ) + + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + """ + Returns the block object defining a layer of the UNet structure including the implementation of the skip + between encoding (down) and decoding (up) sides of the network. + + Args: + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` + """ + + return nn.Sequential(down_path, SkipConnection(subblock), up_path) + + def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module: + """ + In each down layer, the input is first downsampled using maxpooling, + then passed through a convolution block, unless this is the top layer + in which case the input is passed through a convolution block only + without maxpooling first. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. + """ + + mod: Convolution | nn.Sequential + + strides = 1 if self.use_maxpool or is_top else 2 + + mod = Convolution( + self.dimensions, + in_channels, + out_channels, + strides=strides, + kernel_size=self.kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + adn_ordering=self.adn_ordering, + ) + + if self.use_maxpool and not is_top: + mod = ( + nn.Sequential(nn.MaxPool3d(kernel_size=2, stride=2), mod) + if self.dimensions == 3 + else nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), mod) + ) + + return mod + + def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: + """ + Bottom layer (bottleneck) in voxelmorph consists of a typical down layer followed by an upsample layer. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + """ + + mod: nn.Module + upsample: nn.Module + + mod = self._get_down_layer(in_channels, out_channels, is_top=False) + + upsample = UpSample( + self.dimensions, + out_channels, + out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="nearest", + align_corners=None, # required to use with interp_mode="nearest" + ) + + return nn.Sequential(mod, upsample) + + def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module: + """ + In each up layer, the input is passed through a convolution block before upsampled, + unless this is the top layer in which case the input is passed through a convolution block only + without upsampling. + + Args: + in_channels: number of input channels. + out_channels: number of output channels. + is_top: True if this is the top block. + """ + + mod: Convolution | nn.Sequential + + strides = 1 + + mod = Convolution( + self.dimensions, + in_channels, + out_channels, + strides=strides, + kernel_size=self.up_kernel_size, + act=self.act, + norm=self.norm, + dropout=self.dropout, + bias=self.bias, + # conv_only=is_top, + is_transposed=False, + adn_ordering=self.adn_ordering, + ) + + if not is_top: + mod = nn.Sequential( + mod, + UpSample( + self.dimensions, + out_channels, + out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="nearest", + align_corners=None, # required to use with interp_mode="nearest" + ), + ) + + return mod + + def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor: + x = self.net(concatenated_pairs) + return x # type: ignore + + +voxelmorphunet = VoxelMorphUNet + + +@export("monai.networks.nets") +@alias("voxelmorph") +class VoxelMorph(nn.Module): + """ + A re-implementation of VoxelMorph framework for medical image registration as described in + https://arxiv.org/pdf/1809.05231.pdf. For more details, please refer to VoxelMorph: A Learning Framework for + Deformable Medical Image Registration, Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + This class is intended to be a general framework, based on which a deformable image registration + network can be built. Given a user-specified backbone network (e.g., UNet in the original VoxelMorph paper), this + class serves as a wrapper that concatenates the input pair of moving and fixed images, passes through the backbone + network, integrate the predicted stationary velocity field (DVF) from the backbone network to obtain the + displacement field (DDF), and, finally, warp the moving image using the DDF. + + To construct a VoxelMorph network, one need to first construct a backbone network + (e.g., a :py:class:`monai.networks.nets.VoxelMorphUNet`) and pass it to the constructor of + :py:class:`monai.networks.nets.VoxelMorph`. The backbone network should be able to take a pair of moving and fixed + images as input and produce a DVF (or DDF, details to be discussed later) as output. + + When `forward` is called, the input moving and fixed images are first concatenated along the channel dimension and + passed through the specified backbone network to produce the prediction of the displacement field (DDF) in the + non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the stationary velocity field (DVF) in the + diffeomorphic variant (i.e. when `integration_steps` is set to a positive integer). The DVF is then integrated using + a scaling-and-squaring approach via a :py:class:`monai.networks.blocks.warp.DVF2DDF` module to produce the DDF. + Finally, the DDF is used to warp the moving image to the fixed image using a + :py:class:`monai.networks.blocks.warp.Warp` module. Optionally, the integration from DVF to DDF can be + performed on reduced resolution by specifying `half_res` to be True, in which case the output DVF from the backbone + network is first linearly interpolated to half resolution before integration. The output DDF is then linearly + interpolated again back to full resolution before being used to warp the moving image. + + Args: + backbone: a backbone network. + integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring. + Defaults to 7. If set to 0, the network will be non-diffeomorphic. + half_res: whether to perform integration on half resolution. Defaults to False. + spatial_dims: number of spatial dimensions, defaults to 3. + + Example:: + + from monai.networks.nets import VoxelMorphUNet, VoxelMorph + + # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper + # https://arxiv.org/pdf/1809.05231.pdf + + # First, a backbone network is constructed. In this case, we use a VoxelMorphUNet as the backbone network. + backbone = VoxelMorphUNet( + spatial_dims=3, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), # this indicates the down block at the top takes 16 channels as + # input, the corresponding up block at the top produces 32 + # channels as output, the second down block takes 32 channels as + # input, and the corresponding up block at the same level + # produces 32 channels as output, etc. + final_conv_channels=(16, 16) + ) + + # Then, a full VoxelMorph network is constructed using the specified backbone network. + net = VoxelMorph( + backbone=backbone, + integration_steps=7, + half_res=False + ) + + # A forward pass through the network would look something like this + moving = torch.randn(1, 1, 160, 192, 224) + fixed = torch.randn(1, 1, 160, 192, 224) + warped, ddf = net(moving, fixed) + + """ + + def __init__( + self, + backbone: VoxelMorphUNet | nn.Module | None = None, + integration_steps: int = 7, + half_res: bool = False, + spatial_dims: int = 3, + ) -> None: + super().__init__() + + # specified backbone network + self.backbone = ( + backbone + if backbone is not None + else VoxelMorphUNet( + spatial_dims=spatial_dims, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), + final_conv_channels=(16, 16), + ) + ) + + # helper attributes + self.spatial_dims = spatial_dims + self.half_res = half_res + self.integration_steps = integration_steps + self.diffeomorphic = True if self.integration_steps > 0 else False + + # create helpers + if self.diffeomorphic: + self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros") + self.warp = Warp(mode="bilinear", padding_mode="zeros") + + def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if moving.shape != fixed.shape: + raise ValueError( + "The spatial shape of the moving image should be the same as the spatial shape of the fixed image." + f" Got {moving.shape} and {fixed.shape} instead." + ) + + x = self.backbone(torch.cat([moving, fixed], dim=1)) + + if x.shape[1] != self.spatial_dims: + raise ValueError( + "The number of channels in the output of the backbone network should be equal to the" + f" number of spatial dimensions {self.spatial_dims}. Got {x.shape[1]} channels instead." + ) + + if x.shape[2:] != moving.shape[2:]: + raise ValueError( + "The spatial shape of the output of the backbone network should be equal to the" + f" spatial shape of the input images. Got {x.shape[2:]} instead of {moving.shape[2:]}." + ) + + if self.half_res: + x = F.interpolate(x, scale_factor=0.5, mode="trilinear", align_corners=True) * 2.0 + + if self.diffeomorphic: + x = self.dvf2ddf(x) + + if self.half_res: + x = F.interpolate(x * 0.5, scale_factor=2.0, mode="trilinear", align_corners=True) + + return self.warp(moving, x), x + + +voxelmorph = VoxelMorph diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py new file mode 100644 index 0000000000..c51f70cbf5 --- /dev/null +++ b/tests/test_voxelmorph.py @@ -0,0 +1,281 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VoxelMorph, VoxelMorphUNet +from tests.utils import test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_0 = [ # single channel 3D, batch 1, + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_1 = [ # single channel 3D, batch 1, + # using strided convolution for downsampling instead of maxpooling + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "use_maxpool": False, + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_2 = [ # single channel 3D, batch 1, + # using strided convolution for downsampling instead of maxpooling, + # explicitly specify leakyrelu with a different negative slope for final convolutions + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "use_maxpool": False, + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_3 = [ # single channel 3D, batch 1, + # using strided convolution for downsampling instead of maxpooling, + # explicitly specify leakyrelu with a different negative slope for both unet and final convolutions. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "final_conv_act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "act": ("leakyrelu", {"negative_slope": 0.1, "inplace": True}), + "use_maxpool": False, + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_4 = [ # 2-channel 3D, batch 1, + # i.e., possible use case where the input contains both modalities (e.g., T1 and T2) + { + "spatial_dims": 3, + "in_channels": 4, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + (1, 4, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_5 = [ # single channel 3D, batch 2, + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + (2, 2, 96, 96, 48), + (2, 3, 96, 96, 48), +] + +TEST_CASE_6 = [ # single channel 2D, batch 2, + { + "spatial_dims": 2, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + (2, 2, 96, 96), + (2, 2, 96, 96), +] + +TEST_CASE_7 = [ # single channel 3D, batch 1, + # one additional level in the UNet with 32 channels in both down and up branch. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_8 = [ # single channel 3D, batch 1, + # one additional level in the UNet with 32 channels in both down and up branch. + # and removed one of the two final convolution blocks. + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32, 32, 32), + "final_conv_channels": (16,), + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +TEST_CASE_9 = [ # single channel 3D, batch 1, + # only one level in the UNet + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32), + "final_conv_channels": (16, 16), + }, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, +] + +ILL_CASE_0 = [ # spatial_dims = 1 + { + "spatial_dims": 1, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_1 = [ # in_channels = 3 (not divisible by 2) + { + "spatial_dims": 3, + "in_channels": 3, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_2 = [ # len(channels) = 0 + {"spatial_dims": 3, "in_channels": 2, "unet_out_channels": 32, "channels": (), "final_conv_channels": (16, 16)} +] + +ILL_CASE_3 = [ # channels not in pairs + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + } +] + +ILL_CASE_4 = [ # len(kernel_size) = 3, spatial_dims = 2 + { + "spatial_dims": 2, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "kernel_size": (3, 3, 3), + } +] + +ILL_CASE_5 = [ # len(up_kernel_size) = 2, spatial_dims = 3 + { + "spatial_dims": 3, + "in_channels": 2, + "unet_out_channels": 32, + "channels": (16, 32, 32, 32, 32, 32), + "final_conv_channels": (16, 16), + "up_kernel_size": (3, 3), + } +] + +ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5] + + +ILL_CASES_IN_SHAPE_0 = [ # moving and fixed image shape not match + {"spatial_dims": 3}, + (1, 2, 96, 96, 48), + (1, 3, 96, 96, 48), +] + +ILL_CASES_IN_SHAPE_1 = [ # spatial_dims = 2, ddf has 3 channels + {"spatial_dims": 2}, + (1, 1, 96, 96, 96), + (1, 1, 96, 96, 96), +] + +ILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1] + + +class TestVOXELMORPH(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape): + net = VoxelMorphUNet(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + net = VoxelMorphUNet( + spatial_dims=2, + in_channels=2, + unet_out_channels=32, + channels=(16, 32, 32, 32, 32, 32), + final_conv_channels=(16, 16), + ) + test_data = torch.randn(1, 2, 96, 96) + test_script_save(net, test_data) + + @parameterized.expand(ILL_CASES) + def test_ill_input_hyper_params(self, input_param): + with self.assertRaises(ValueError): + _ = VoxelMorphUNet(**input_param) + + @parameterized.expand(ILL_CASES_IN_SHAPE) + def test_ill_input_shape(self, input_param, moving_shape, fixed_shape): + with self.assertRaises((ValueError, RuntimeError)): + net = VoxelMorph(**input_param).to(device) + with eval_mode(net): + _ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device)) + + +if __name__ == "__main__": + unittest.main() From 016b386ed9c14e07b02652d1c233e4160ec38977 Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 6 Nov 2023 08:57:08 +0000 Subject: [PATCH 07/67] auto updates (#7203) Signed-off-by: monai-bot Signed-off-by: monai-bot Signed-off-by: Mark Graham --- tests/test_squeeze_unsqueeze.py | 1 - tests/test_voxelmorph.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index 2db26a6bdc..130a214345 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -28,7 +28,6 @@ (torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)), ] - LEFT_CASES = [ (np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)), diff --git a/tests/test_voxelmorph.py b/tests/test_voxelmorph.py index c51f70cbf5..53ef2fc18f 100644 --- a/tests/test_voxelmorph.py +++ b/tests/test_voxelmorph.py @@ -229,7 +229,6 @@ ILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5] - ILL_CASES_IN_SHAPE_0 = [ # moving and fixed image shape not match {"spatial_dims": 3}, (1, 2, 96, 96, 48), From a2c671a3bffcd9646b453487a42fdd77c060ec65 Mon Sep 17 00:00:00 2001 From: elitap Date: Wed, 15 Nov 2023 12:00:21 +0100 Subject: [PATCH 08/67] fix useless error msg in nnunetv2runner (#7217) Fixes # fix useless error msg in nnunetv2runner ### Description ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). Signed-off-by: elitap Signed-off-by: Mark Graham --- monai/apps/nnunet/nnunetv2_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index a3357cd9b3..e62809403e 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -275,8 +275,8 @@ def convert_dataset(self): num_input_channels=num_input_channels, output_datafolder=raw_data_foldername, ) - except BaseException: - logger.warning("Input config may be incorrect. Detail info: error/exception message is:\n {err}") + except BaseException as err: + logger.warning(f"Input config may be incorrect. Detail info: error/exception message is:\n {err}") return def convert_msd_dataset(self, data_dir: str, overwrite_id: str | None = None, n_proc: int = -1) -> None: From 4bad4bcaaeaf400b357b818ae195f23091441a0d Mon Sep 17 00:00:00 2001 From: Felix Schnabel Date: Thu, 16 Nov 2023 19:49:14 +0100 Subject: [PATCH 09/67] Fixup mypy 1.7.0 errors (#7231) Fixes #7230. ### Description Fix the typing issues and the deprecation. Also always run type checking with Linux environment, since ForkServerContext is not available on Windows. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Felix Schnabel Signed-off-by: Mark Graham --- .gitignore | 1 + monai/apps/auto3dseg/data_analyzer.py | 2 +- monai/apps/deepgrow/interaction.py | 2 +- monai/apps/pathology/metrics/lesion_froc.py | 5 ++++- monai/metrics/utils.py | 4 ++-- monai/transforms/io/array.py | 2 +- requirements-dev.txt | 2 +- setup.cfg | 2 +- 8 files changed, 12 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 8c66d4a651..437677d2bb 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,7 @@ venv.bak/ examples/scd_lvsegs.npz temp/ .idea/ +.dmypy.json *~ diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 2c485f03eb..9280fb5be5 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -210,7 +210,7 @@ def get_all_case_stats(self, key="training", transform_list=None): nprocs = torch.cuda.device_count() logger.info(f"Found {nprocs} GPUs for data analyzing!") if nprocs > 1: - tmp_ctx = get_context("forkserver") + tmp_ctx: Any = get_context("forkserver") with tmp_ctx.Manager() as manager: manager_list = manager.list() processes = [] diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 88211c31e3..fa3a28bfef 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -49,7 +49,7 @@ def __init__( if not isinstance(transforms, Compose): transforms = Compose(transforms) - self.transforms = transforms + self.transforms: Compose = transforms self.max_interactions = max_interactions self.train = train self.key_probability = key_probability diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py index 0899de9a76..f4bf51ab28 100644 --- a/monai/apps/pathology/metrics/lesion_froc.py +++ b/monai/apps/pathology/metrics/lesion_froc.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable import numpy as np @@ -94,6 +94,9 @@ def prepare_inference_result(self, sample: dict) -> tuple[np.ndarray, np.ndarray nms_outputs = self.nms(probs_map=prob_map, resolution_level=sample["level"]) # separate nms outputs + probs: Iterable[Any] + x_coord: Iterable[Any] + y_coord: Iterable[Any] if nms_outputs: probs, x_coord, y_coord = zip(*nms_outputs) else: diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index c139fc35ed..4d4e6570c5 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -14,7 +14,7 @@ import warnings from functools import lru_cache, partial from types import ModuleType -from typing import Any, Sequence +from typing import Any, Iterable, Sequence import numpy as np import torch @@ -383,7 +383,7 @@ def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor by_size: if True, largest instance will be assigned a smaller id. """ - pred_id = list(pred.unique()) + pred_id: Iterable[Any] = list(pred.unique()) # the original implementation has the limitation that if there is no 0 in pred, error will happen pred_id = [i for i in pred_id if i != 0] diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index b36c011822..cd7e4ef090 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -268,7 +268,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break if img is None or reader is None: - if isinstance(filename, tuple) and len(filename) == 1: + if isinstance(filename, Sequence) and len(filename) == 1: filename = filename[0] msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( diff --git a/requirements-dev.txt b/requirements-dev.txt index 38715b8449..6332d5b0a5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,7 +22,7 @@ isort>=5.1 ruff pytype>=2020.6.1; platform_system != "Windows" types-pkg_resources -mypy>=0.790 +mypy>=1.5.0 ninja torchvision psutil diff --git a/setup.cfg b/setup.cfg index d6c9b4f190..123da68dfa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -228,7 +228,7 @@ pretty = False # Warns about per-module sections in the config file that do not match any files processed when invoking mypy. warn_unused_configs = True # Make arguments prepended via Concatenate be truly positional-only. -strict_concatenate = True +extra_checks = True # Allows variables to be redefined with an arbitrary type, # as long as the redefinition is in the same block and nesting level as the original definition. # allow_redefinition = True From b9b57c8c57df10769ef1f0f9f703be689adbb54d Mon Sep 17 00:00:00 2001 From: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:21:56 +0800 Subject: [PATCH 10/67] add Yun Liu to user list to trigger blossom-ci [skip ci] (#7239) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Add new developer to blossom-ci trigger list Signed-off-by: YanxuanLiu Signed-off-by: Mark Graham --- .github/workflows/blossom-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 5e4612c6c5..1d6ee8a46c 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -34,6 +34,7 @@ jobs: wyli,\ pxLi,\ YanxuanLiu,\ + KumoLiu,\ ', format('{0},', github.actor)) && github.event.comment.body == '/build' steps: - name: Check if comment is issued by authorized person From a22b5dd616b6344cdba798659e1293d0d9a6d230 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 17 Nov 2023 16:18:07 +0800 Subject: [PATCH 11/67] =?UTF-8?q?Replace=20single=20quotation=20marks=20wi?= =?UTF-8?q?th=20double=20quotation=20marks=20to=20install=E2=80=A6=20(#723?= =?UTF-8?q?4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … MONAI with all dependencies on Windows Fixes #6118 ### Description The Windows shell doesn't recognize single quotes to delimit a string at all, so on Windows you'll need to use double quotes. It's the same command no matter which type of quotes you use; after the shell does its processing, the argument is passed to pip with the quotation marks already removed. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: ytl0623 Signed-off-by: Mark Graham --- docs/source/installation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index 88107c9487..d77253f0f9 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -228,7 +228,7 @@ Alternatively, to install all optional dependencies: ```bash git clone https://github.com/Project-MONAI/MONAI.git cd MONAI/ -pip install -e '.[all]' +pip install -e ".[all]" ``` To install all optional dependencies with `pip` based on MONAI development environment settings: From 8f781eb70e3ec8d9fa3659439d1c282f0e749dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C5=BEenan=20Zuki=C4=87?= Date: Fri, 17 Nov 2023 04:56:52 -0500 Subject: [PATCH 12/67] Update bug_report.md (#7213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoid syntax error on Windows ### Description ```log (.venv) C:\Dev\Kitware\python>python -c 'import monai; monai.config.print_debug_info()' File "", line 1 'import ^ SyntaxError: unterminated string literal (detected at line 1) ``` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Dženan Zukić Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index cebcdfc917..4ae6c07732 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -27,7 +27,7 @@ If applicable, add screenshots to help explain your problem. Ensuring you use the relevant python executable, please paste the output of: ``` -python -c 'import monai; monai.config.print_debug_info()' +python -c "import monai; monai.config.print_debug_info()" ``` **Additional context** From bf166b574ec3f8c0a86496bf7e433ad5897e50a6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:14:17 +0800 Subject: [PATCH 13/67] Add cache option in `GridPatchDataset` (#7180) Part of #6904 ### Description - Fix inefficient patching in `PatchDataset` - Add cache option in `GridPatchDataset` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/data/grid_dataset.py | 218 ++++++++++++++++++++++++++++++------ tests/test_grid_dataset.py | 55 +++++++-- tests/test_patch_dataset.py | 15 ++- 3 files changed, 242 insertions(+), 46 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 06954e9f11..9079032e6f 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,18 +11,30 @@ from __future__ import annotations -from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence +import sys +import warnings +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy +from multiprocessing.managers import ListProxy +from multiprocessing.pool import ThreadPool +from typing import TYPE_CHECKING import numpy as np +import torch from monai.config import KeysCollection from monai.config.type_definitions import NdarrayTensor -from monai.data.dataset import Dataset from monai.data.iterable_dataset import IterableDataset -from monai.data.utils import iter_patch -from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, first +from monai.data.utils import iter_patch, pickle_hashing +from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous +from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import + +if TYPE_CHECKING: + from tqdm import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset): see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`. transform: a callable data transform operates on the patches. with_coordinates: whether to yield the coordinates of each patch, default to `True`. + cache: whether to use cache mache mechanism, default to `False`. + see also: :py:class:`monai.data.CacheDataset`. + cache_num: number of items to be cached. Default is `sys.maxsize`. + will take the minimum of (cache_num, data_length x cache_rate, data_length). + cache_rate: percentage of cached data in total, default is 1.0 (cache all). + will take the minimum of (cache_num, data_length x cache_rate, data_length). + num_workers: the number of worker threads if computing cache in the initialization. + If num_workers is None then the number returned by os.cpu_count() is used. + If a value less than 1 is specified, 1 will be used instead. + progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy=False` for better performance. + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. + it may help improve the performance of following logic. + hash_func: a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. """ @@ -193,27 +224,148 @@ def __init__( patch_iter: Callable, transform: Callable | None = None, with_coordinates: bool = True, + cache: bool = False, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int | None = 1, + progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: super().__init__(data=data, transform=None) + if transform is not None and not isinstance(transform, Compose): + transform = Compose(transform) self.patch_iter = patch_iter self.patch_transform = transform self.with_coordinates = with_coordinates + self.set_num = cache_num + self.set_rate = cache_rate + self.progress = progress + self.copy_cache = copy_cache + self.as_contiguous = as_contiguous + self.hash_func = hash_func + self.num_workers = num_workers + if self.num_workers is not None: + self.num_workers = max(int(self.num_workers), 1) + self._cache: list | ListProxy = [] + self._cache_other: list | ListProxy = [] + self.cache = cache + self.first_random: int | None = None + if self.patch_transform is not None: + self.first_random = self.patch_transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) - def __iter__(self): - for image in super().__iter__(): - for patch, *others in self.patch_iter(image): - out_patch = patch - if self.patch_transform is not None: - out_patch = apply_transform(self.patch_transform, patch, map_items=False) - if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others[0] - else: - yield out_patch + if self.cache: + if isinstance(data, Iterator): + raise TypeError("Data can not be iterator when cache is True") + self.set_data(data) # type: ignore + + def set_data(self, data: Sequence) -> None: + """ + Set the input data and run deterministic transforms to generate cache content. + + Note: should call this func after an entire epoch and must set `persistent_workers=False` + in PyTorch DataLoader, because it needs to create new worker processes based on new + generated cache content. + + """ + self.data = data + + # only compute cache for the unique items of dataset, and record the last index for duplicated items + mapping = {self.hash_func(v): i for i, v in enumerate(self.data)} + self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping)) + self._hash_keys = list(mapping)[: self.cache_num] + indices = list(mapping.values())[: self.cache_num] + self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore + + def _fill_cache(self, indices=None) -> list: + """ + Compute and fill the cache content from data source. + + Args: + indices: target indices in the `self.data` source to compute cache. + if None, use the first `cache_num` items. + + """ + if self.cache_num <= 0: + return [] + if indices is None: + indices = list(range(self.cache_num)) + if self.progress and not has_tqdm: + warnings.warn("tqdm is not installed, will not show the caching progress bar.") + + pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v) + with ThreadPool(self.num_workers) as p: + return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) + + def _load_cache_item(self, idx: int): + """ + Args: + idx: the index of the input data sequence. + """ + item = self.data[idx] # type: ignore + patch_cache, other_cache = [], [] + for patch, *others in self.patch_iter(item): + if self.first_random is not None: + patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore + + if self.as_contiguous: + patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + other_cache.append(others[0]) + patch_cache.append(patch) + return patch_cache, other_cache + + def _generate_patches(self, src, **apply_args): + """ + yield patches optionally post-processed by transform. + Args: + src: a iterable of image patches. + apply_args: other args for `self.patch_transform`. + + """ + for patch, *others in src: + out_patch = patch + if self.patch_transform is not None: + out_patch = self.patch_transform(patch, **apply_args) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: + yield out_patch -class PatchDataset(Dataset): + def __iter__(self): + if self.cache: + cache_index = None + for image in super().__iter__(): + key = self.hash_func(image) + if key in self._hash_keys: + # if existing in cache, try to get the index in cache + cache_index = self._hash_keys.index(key) + if cache_index is None: + # no cache for this index, execute all the transforms directly + yield from self._generate_patches(self.patch_iter(image)) + else: + if self._cache is None: + raise RuntimeError( + "Cache buffer is not initialized, please call `set_data()` before epoch begins." + ) + data = self._cache[cache_index] # type: ignore + other = self._cache_other[cache_index] # type: ignore + + # load data from cache and execute from the first random transform + data = deepcopy(data) if self.copy_cache else data + yield from self._generate_patches(zip(data, other), start=self.first_random) + else: + for image in super().__iter__(): + yield from self._generate_patches(self.patch_iter(image)) + + +class PatchDataset(IterableDataset): """ - returns a patch from an image dataset. + Yields patches from data read from an image dataset. The patches are generated by a user-specified callable `patch_func`, and are optionally post-processed by `transform`. For example, to generate random patch samples from an image dataset: @@ -263,26 +415,26 @@ def __init__( samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements. transform: transform applied to each patch. """ - super().__init__(data=data, transform=transform) + super().__init__(data=data, transform=None) self.patch_func = patch_func if samples_per_image <= 0: raise ValueError("sampler_per_image must be a positive integer.") self.samples_per_image = int(samples_per_image) + self.patch_transform = transform def __len__(self) -> int: - return len(self.data) * self.samples_per_image - - def _transform(self, index: int): - image_id = int(index / self.samples_per_image) - image = self.data[image_id] - patches = self.patch_func(image) - if len(patches) != self.samples_per_image: - raise RuntimeWarning( - f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." - ) - patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1) - patch = patches[patch_id] - if self.transform is not None: - patch = apply_transform(self.transform, patch, map_items=False) - return patch + return len(self.data) * self.samples_per_image # type: ignore + + def __iter__(self): + for image in super().__iter__(): + patches = self.patch_func(image) + if len(patches) != self.samples_per_image: + raise RuntimeWarning( + f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." + ) + for patch in patches: + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + yield out_patch diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index ba33547260..d937a5e266 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -108,11 +108,10 @@ def test_shape(self): self.assertEqual(sorted(output), sorted(expected)) def test_loading_array(self): - set_determinism(seed=1234) # test sequence input data with images images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] # image level - patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) + patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234) patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset @@ -120,7 +119,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -129,9 +128,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( @@ -164,7 +161,7 @@ def test_loading_dict(self): self.assertListEqual(item[0]["metadata"], ["test string", "test string"]) np.testing.assert_allclose( item[0]["image"], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -173,15 +170,53 @@ def test_loading_dict(self): np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2)) np.testing.assert_allclose( item[0]["image"], - np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5 ) + def test_set_data(self): + from monai.transforms import Compose, Lambda, RandLambda + + images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] + + transform = Compose( + [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False + ) + patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) + dataset = GridPatchDataset( + data=images, + patch_iter=patch_iter, + transform=transform, + cache=True, + cache_rate=1.0, + copy_cache=not sys.platform == "linux", + ) + + num_workers = 2 if sys.platform == "linux" else 0 + for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + # simulate another epoch, the cache content should not be modified + for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + + # update the datalist and fill the cache content + data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)] + dataset.set_data(data=data_list2) + # rerun with updated cache content + for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose( + item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4 + ) + np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 7d66bdccbb..eb705f0c61 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -37,7 +37,10 @@ def test_shape(self): n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["vwx", "yzh", "ell", "owo", "rld"] + if n_workers == 0: + expected = ["vwx", "yzh", "ell", "owo", "rld"] + else: + expected = ["vwx", "hel", "yzw", "lo", "orl", "d"] self.assertEqual(output, expected) def test_loading_array(self): @@ -61,7 +64,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]] + [[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]] ), rtol=1e-5, ) @@ -71,7 +74,13 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]] + [ + [ + [5.028125, 6.028125, 7.028125], + [9.028125, 10.028125, 11.028125], + [13.028125, 14.028125, 15.028125], + ] + ] ), rtol=1e-5, ) From 2ecb612bb9ad713d4f8e88d45491587b00c8e49c Mon Sep 17 00:00:00 2001 From: Ishan Dutta Date: Sun, 19 Nov 2023 21:49:16 +0530 Subject: [PATCH 14/67] :memo: [array] Add examples for EnsureType and CastToType (#7245) Fixes #7101 ### Description Added examples in the docstrings for `EnsureType` and `CastToType` transforms which show how they function under different circumstances. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Ishan Dutta Signed-off-by: Mark Graham --- monai/transforms/utility/array.py | 35 ++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9aad12ef90..caf02d7b00 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -333,6 +333,23 @@ class CastToType(Transform): """ Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to specified PyTorch data type. + + Example: + >>> import numpy as np + >>> import torch + >>> transform = CastToType(dtype=np.float32) + + >>> # Example with a numpy array + >>> img_np = np.array([0, 127, 255], dtype=np.uint8) + >>> img_np_casted = transform(img_np) + >>> img_np_casted + array([ 0. , 127. , 255. ], dtype=float32) + + >>> # Example with a PyTorch tensor + >>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8) + >>> img_tensor_casted = transform(img_tensor) + >>> img_tensor_casted + tensor([ 0., 127., 255.]) # dtype is float32 """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -413,10 +430,26 @@ class EnsureType(Transform): dtype: target data content type to convert, for example: np.float32, torch.float, etc. device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. - E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``, if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`. + Example with wrap_sequence=True: + >>> import numpy as np + >>> import torch + >>> transform = EnsureType(data_type="tensor", wrap_sequence=True) + >>> # Converting a list to a tensor + >>> data_list = [1, 2., 3] + >>> tensor_data = transform(data_list) + >>> tensor_data + tensor([1., 2., 3.]) # All elements have dtype float32 + + Example with wrap_sequence=False: + >>> transform = EnsureType(data_type="tensor", wrap_sequence=False) + >>> # Converting each element in a list to individual tensors + >>> data_list = [1, 2, 3] + >>> tensors_list = transform(data_list) + >>> tensors_list + [tensor(1), tensor(2.), tensor(3)] # Only second element is float32 rest are int64 """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] From e9673e82364867d7ab10de55cc214067efe277ef Mon Sep 17 00:00:00 2001 From: Ishan Dutta Date: Mon, 20 Nov 2023 08:02:39 +0530 Subject: [PATCH 15/67] :hammer: [dataset] Handle corrupted cached file in PersistentDataset (#7244) Fixes #5723 ### Description Corrupted cached files in the PersistentDataset cause the exception: `RuntimeError: Invalid magic number; corrupt file?` With this PR we handle that case in the try-except block and continue the usual functionality if the cached file was absent. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Ishan Dutta Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/data/dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 5e403d6fdb..eba850225d 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -387,6 +387,12 @@ def _cachecheck(self, item_transformed): except PermissionError as e: if sys.platform != "win32": raise e + except RuntimeError as e: + if "Invalid magic number; corrupt file" in str(e): + warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.") + hashfile.unlink() + else: + raise e _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed if hashfile is None: From fcb4ff468fcc2ef17bcedb28806f846dce6ab3b3 Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 20 Nov 2023 08:34:18 +0000 Subject: [PATCH 16/67] auto updates (#7247) Signed-off-by: monai-bot Signed-off-by: monai-bot Signed-off-by: Mark Graham --- monai/apps/detection/transforms/box_ops.py | 2 +- monai/data/grid_dataset.py | 4 ++-- monai/data/image_writer.py | 2 +- monai/data/wsi_reader.py | 2 +- monai/metrics/utils.py | 2 +- monai/networks/nets/resnet.py | 6 +++--- monai/transforms/croppad/array.py | 8 +++----- monai/transforms/spatial/array.py | 2 +- monai/transforms/utility/array.py | 2 +- monai/transforms/utils.py | 8 ++++---- tests/test_inverse.py | 2 +- tests/utils.py | 4 ++-- 12 files changed, 21 insertions(+), 23 deletions(-) diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py index fb870c952e..404854c4c0 100644 --- a/monai/apps/detection/transforms/box_ops.py +++ b/monai/apps/detection/transforms/box_ops.py @@ -407,7 +407,7 @@ def rot90_boxes( spatial_dims: int = get_spatial_dims(boxes=boxes) spatial_size_ = list(ensure_tuple_rep(spatial_size, spatial_dims)) - axes = ensure_tuple(axes) # type: ignore + axes = ensure_tuple(axes) if len(axes) != 2: raise ValueError("len(axes) must be 2.") diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 9079032e6f..689138179a 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -352,8 +352,8 @@ def __iter__(self): raise RuntimeError( "Cache buffer is not initialized, please call `set_data()` before epoch begins." ) - data = self._cache[cache_index] # type: ignore - other = self._cache_other[cache_index] # type: ignore + data = self._cache[cache_index] + other = self._cache_other[cache_index] # load data from cache and execute from the first random transform data = deepcopy(data) if self.copy_cache else data diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index db0bfa96b8..b9e8b9e68e 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -276,7 +276,7 @@ def resample_if_needed( # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] - data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore + data_array, *_ = convert_data_type(output_array, output_type=orig_type) affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 54e12eb0cd..b31d4d9c3a 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -111,7 +111,7 @@ def __init__( self.set_device(device) self.mode = mode self.kwargs = kwargs - self.mpp: tuple[float, float] | None = ensure_tuple_rep(mpp, 2) if mpp is not None else None # type: ignore + self.mpp: tuple[float, float] | None = ensure_tuple_rep(mpp, 2) if mpp is not None else None self.power = power self.mpp_rtol = mpp_rtol self.mpp_atol = mpp_atol diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 4d4e6570c5..62e6520b96 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -205,7 +205,7 @@ def get_mask_edges( or_vol = seg_pred | seg_gt if not or_vol.any(): pred, gt = lib.zeros(seg_pred.shape, dtype=bool), lib.zeros(seg_gt.shape, dtype=bool) - return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore + return (pred, gt) if spacing is None else (pred, gt, pred, gt) channel_first = [seg_pred[None], seg_gt[None], or_vol[None]] if spacing is None and not use_cucim: # cpu only erosion seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index fca73f4de3..34a4b7057e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -239,9 +239,9 @@ def __init__( self.conv1 = conv_type( n_input_channels, self.in_planes, - kernel_size=conv1_kernel_size, # type: ignore - stride=conv1_stride, # type: ignore - padding=tuple(k // 2 for k in conv1_kernel_size), # type: ignore + kernel_size=conv1_kernel_size, + stride=conv1_stride, + padding=tuple(k // 2 for k in conv1_kernel_size), bias=False, ) self.bn1 = norm_type(self.in_planes) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 6a3798e7ba..ce3701b263 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -386,7 +386,7 @@ def compute_slices( if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError(f"only slice steps of 1/None are currently supported, got {roi_slices}.") - return ensure_tuple(roi_slices) # type: ignore + return ensure_tuple(roi_slices) else: if roi_center is not None and roi_size is not None: roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu") @@ -408,10 +408,8 @@ def compute_slices( roi_end_t = torch.maximum(roi_end_t, roi_start_t) # convert to slices (accounting for 1d) if roi_start_t.numel() == 1: - return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) # type: ignore - return ensure_tuple( # type: ignore - [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] - ) + return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))]) + return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())]) def __call__( # type: ignore[override] self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9d55aa013b..8ad86b72dd 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1157,7 +1157,7 @@ def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: boo """ LazyTransform.__init__(self, lazy=lazy) self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 - spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore + spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes) if len(spatial_axes_) != 2: raise ValueError(f"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.") self.spatial_axes = spatial_axes_ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index caf02d7b00..2322f2123f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -372,7 +372,7 @@ def __call__(self, img: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None) TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ - return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0] # type: ignore + return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0] class ToTensor(Transform): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 678219991f..e282ecff24 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -521,7 +521,7 @@ def correct_crop_centers( for c, v_s, v_e in zip(centers, valid_start, valid_end): center_i = min(max(c, v_s), v_e - 1) valid_centers.append(int(center_i)) - return ensure_tuple(valid_centers) # type: ignore + return ensure_tuple(valid_centers) def generate_pos_neg_label_crop_centers( @@ -579,7 +579,7 @@ def generate_pos_neg_label_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return ensure_tuple(centers) # type: ignore + return ensure_tuple(centers) def generate_label_classes_crop_centers( @@ -639,7 +639,7 @@ def generate_label_classes_crop_centers( # shift center to range of valid centers centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) - return ensure_tuple(centers) # type: ignore + return ensure_tuple(centers) def create_grid( @@ -2218,7 +2218,7 @@ def distance_transform_edt( if not r_vals: return None device = img.device if isinstance(img, torch.Tensor) else None - return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] # type: ignore + return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] if __name__ == "__main__": diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 3f07b43d6d..6bd14a19f1 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -310,7 +310,7 @@ TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], t[3], Compose(Compose(t[4:]))) for t in TESTS] -TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore +TESTS = TESTS + TESTS_COMPOSE_X2 NUM_SAMPLES = 5 N_SAMPLES_TESTS = [ diff --git a/tests/utils.py b/tests/utils.py index cf1711292f..ee800598bb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -832,9 +832,9 @@ def equal_state_dict(st_1, st_2): [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]] ) _metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE}) -TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore +TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS TEST_NDARRAYS: tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore -TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore +TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # alias for branch tests TEST_NDARRAYS_ALL = TEST_NDARRAYS From a99ea956d50ea8f592e7999e46214dc2acc8ced7 Mon Sep 17 00:00:00 2001 From: elitap Date: Wed, 22 Nov 2023 06:06:17 +0100 Subject: [PATCH 17/67] =?UTF-8?q?add=20class=20label=20option=20to=20write?= =?UTF-8?q?=20metric=20report=20to=20improve=20readability=20=E2=80=A6=20(?= =?UTF-8?q?#7249)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add class label option to write metric report to improve readability, without that option in case of many classes the resulting report is very hard to interpret. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: elitap Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/handlers/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 58a3fd36f3..0cd31b89c2 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -61,6 +61,7 @@ def write_metrics_reports( summary_ops: str | Sequence[str] | None, deli: str = ",", output_type: str = "csv", + class_labels: list[str] | None = None, ) -> None: """ Utility function to write the metrics into files, contains 3 parts: @@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans deli: the delimiter character in the saved file, default to "," as the default output type is `csv`. to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter. output_type: expected output file type, supported types: ["csv"], default to "csv". + class_labels: list of class names used to name the classes in the output report, if None, + "class0", ..., "classn" are used, default to None. """ if output_type.lower() != "csv": @@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans v = v.reshape((-1, 1)) # add the average value of all classes to v - class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"] + if class_labels is None: + class_labels = ["class" + str(i) for i in range(v.shape[1])] + else: + class_labels = [str(i) for i in class_labels] # ensure to have a list of str + + class_labels += ["mean"] v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: From ee68d507c52095ffdc3f56d5b0c49a42ffec4d4c Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 28 Nov 2023 00:18:51 +0800 Subject: [PATCH 18/67] Fix B026 unrecommanded star-arg unpacking after a keyword argument (#7262) Fixes #7261 ### Description Remove star-arg unpacking before a keyword argument. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu Signed-off-by: Mark Graham --- monai/data/image_reader.py | 2 +- monai/inferers/inferer.py | 2 +- tests/test_video_datasets.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index fe199d9570..0823d11834 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1300,7 +1300,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index bf8c27e5c3..0b4199938d 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -584,10 +584,10 @@ def __call__( return super().__call__( inputs, network, + *args, device=inputs.device if gpu_stitching else torch.device("cpu"), buffer_steps=buffer_steps if buffered_stitching else None, buffer_dim=buffer_dim, - *args, **kwargs, ) except RuntimeError as e: diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py index eedbe212eb..790feb51ee 100644 --- a/tests/test_video_datasets.py +++ b/tests/test_video_datasets.py @@ -39,7 +39,7 @@ def get_video_source(self): return self.video_source def get_ds(self, *args, **kwargs) -> VideoDataset: - return self.ds(video_source=self.get_video_source(), transform=TRANSFORMS, *args, **kwargs) # type: ignore + return self.ds(*args, video_source=self.get_video_source(), transform=TRANSFORMS, **kwargs) # type: ignore @unittest.skipIf(has_cv2, "Only tested when OpenCV not installed.") def test_no_opencv_raises(self): From 62ad35c4ac0e5087154207ec56e573b75d8db8ba Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 30 Nov 2023 10:40:00 +0800 Subject: [PATCH 19/67] Quote $PY_EXE variable to deal with Python path that contain spaces in Bash (#7268) Fixes #5857. ### Description When dealing with paths that contain spaces in Bash, it's important to properly quote the variables to ensure that spaces are handled correctly. So, maybe we can replace all `$PY_EXE` variables to `"$PY_EXE"` in the `runtests.sh` file. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: ytl0623 Signed-off-by: Mark Graham --- runtests.sh | 70 ++++++++++++++++++++++++++--------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/runtests.sh b/runtests.sh index cfceb6976a..0c60bc0f58 100755 --- a/runtests.sh +++ b/runtests.sh @@ -119,42 +119,42 @@ function print_usage { } # FIXME: https://github.com/Project-MONAI/MONAI/issues/4354 -protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1) +protobuf_major_version=$("${PY_EXE}" -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1) if [ "$protobuf_major_version" -ge "4" ] then export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python fi function check_import { - echo "Python: ${PY_EXE}" - ${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -W ignore::ResourceWarning -c "import monai" + echo "Python: "${PY_EXE}"" + ${cmdPrefix}"${PY_EXE}" -W error -W ignore::DeprecationWarning -W ignore::ResourceWarning -c "import monai" } function print_version { - ${cmdPrefix}${PY_EXE} -c 'import monai; monai.config.print_config()' # project-monai/monai#6167 + ${cmdPrefix}"${PY_EXE}" -c 'import monai; monai.config.print_config()' # project-monai/monai#6167 } function install_deps { echo "Pip installing MONAI development dependencies and compile MONAI cpp extensions..." - ${cmdPrefix}${PY_EXE} -m pip install -r requirements-dev.txt + ${cmdPrefix}"${PY_EXE}" -m pip install -r requirements-dev.txt } function compile_cpp { echo "Compiling and installing MONAI cpp extensions..." # depends on setup.py behaviour for building # currently setup.py uses environment variables: BUILD_MONAI and FORCE_CUDA - ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall + ${cmdPrefix}"${PY_EXE}" setup.py develop --user --uninstall if [[ "$OSTYPE" == "darwin"* ]]; then # clang for mac os - CC=clang CXX=clang++ ${cmdPrefix}${PY_EXE} setup.py develop --user + CC=clang CXX=clang++ ${cmdPrefix}"${PY_EXE}" setup.py develop --user else - ${cmdPrefix}${PY_EXE} setup.py develop --user + ${cmdPrefix}"${PY_EXE}" setup.py develop --user fi } function clang_format { echo "Running clang-format..." - ${cmdPrefix}${PY_EXE} -m tests.clang_format_utils + ${cmdPrefix}"${PY_EXE}" -m tests.clang_format_utils clang_format_tool='.clang-format-bin/clang-format' # Verify . if ! type -p "$clang_format_tool" >/dev/null; then @@ -167,19 +167,19 @@ function clang_format { } function is_pip_installed() { - return $(${PY_EXE} -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1) + return $("${PY_EXE}" -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1) } function clean_py { if is_pip_installed coverage then # remove coverage history - ${cmdPrefix}${PY_EXE} -m coverage erase + ${cmdPrefix}"${PY_EXE}" -m coverage erase fi # uninstall the development package echo "Uninstalling MONAI development files..." - ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall + ${cmdPrefix}"${PY_EXE}" setup.py develop --user --uninstall # remove temporary files (in the directory of this script) TO_CLEAN="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" @@ -201,7 +201,7 @@ function clean_py { } function torch_validate { - ${cmdPrefix}${PY_EXE} -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + ${cmdPrefix}"${PY_EXE}" -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' } function print_error_msg() { @@ -219,7 +219,7 @@ function print_style_fail_msg() { } function list_unittests() { - ${PY_EXE} - << END + "${PY_EXE}" - << END import unittest def print_suite(suite): if hasattr(suite, "__iter__"): @@ -448,7 +448,7 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files + ${cmdPrefix}"${PY_EXE}" -m pre_commit run --all-files pre_commit_status=$? if [ ${pre_commit_status} -ne 0 ] @@ -477,13 +477,13 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m isort --version + ${cmdPrefix}"${PY_EXE}" -m isort --version if [ $doIsortFix = true ] then - ${cmdPrefix}${PY_EXE} -m isort "$homedir" + ${cmdPrefix}"${PY_EXE}" -m isort "$homedir" else - ${cmdPrefix}${PY_EXE} -m isort --check "$homedir" + ${cmdPrefix}"${PY_EXE}" -m isort --check "$homedir" fi isort_status=$? @@ -513,13 +513,13 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m black --version + ${cmdPrefix}"${PY_EXE}" -m black --version if [ $doBlackFix = true ] then - ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma "$homedir" + ${cmdPrefix}"${PY_EXE}" -m black --skip-magic-trailing-comma "$homedir" else - ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma --check "$homedir" + ${cmdPrefix}"${PY_EXE}" -m black --skip-magic-trailing-comma --check "$homedir" fi black_status=$? @@ -544,9 +544,9 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m flake8 --version + ${cmdPrefix}"${PY_EXE}" -m flake8 --version - ${cmdPrefix}${PY_EXE} -m flake8 "$homedir" --count --statistics + ${cmdPrefix}"${PY_EXE}" -m flake8 "$homedir" --count --statistics flake8_status=$? if [ ${flake8_status} -ne 0 ] @@ -568,12 +568,12 @@ then if ! is_pip_installed pylint then echo "Pip installing pylint ..." - ${cmdPrefix}${PY_EXE} -m pip install "pylint>2.16,!=3.0.0" + ${cmdPrefix}"${PY_EXE}" -m pip install "pylint>2.16,!=3.0.0" fi - ${cmdPrefix}${PY_EXE} -m pylint --version + ${cmdPrefix}"${PY_EXE}" -m pylint --version ignore_codes="C,R,W,E1101,E1102,E0601,E1130,E1123,E0102,E1120,E1137,E1136" - ${cmdPrefix}${PY_EXE} -m pylint monai tests --disable=$ignore_codes -j $NUM_PARALLEL + ${cmdPrefix}"${PY_EXE}" -m pylint monai tests --disable=$ignore_codes -j $NUM_PARALLEL pylint_status=$? if [ ${pylint_status} -ne 0 ] @@ -632,14 +632,14 @@ then then install_deps fi - pytype_ver=$(${cmdPrefix}${PY_EXE} -m pytype --version) + pytype_ver=$(${cmdPrefix}"${PY_EXE}" -m pytype --version) if [[ "$OSTYPE" == "darwin"* && "$pytype_ver" == "2021."* ]]; then echo "${red}pytype not working on macOS 2021 (https://github.com/Project-MONAI/MONAI/issues/2391). Please upgrade to 2022*.${noColor}" exit 1 else - ${cmdPrefix}${PY_EXE} -m pytype --version + ${cmdPrefix}"${PY_EXE}" -m pytype --version - ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$homedir" + ${cmdPrefix}"${PY_EXE}" -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$homedir" pytype_status=$? if [ ${pytype_status} -ne 0 ] @@ -664,8 +664,8 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m mypy --version - ${cmdPrefix}${PY_EXE} -m mypy "$homedir" + ${cmdPrefix}"${PY_EXE}" -m mypy --version + ${cmdPrefix}"${PY_EXE}" -m mypy "$homedir" mypy_status=$? if [ ${mypy_status} -ne 0 ] @@ -695,7 +695,7 @@ if [ $doMinTests = true ] then echo "${separator}${blue}min${noColor}" doCoverage=false - ${cmdPrefix}${PY_EXE} -m tests.min_tests + ${cmdPrefix}"${PY_EXE}" -m tests.min_tests fi # set coverage command @@ -707,7 +707,7 @@ then then install_deps fi - cmd="${PY_EXE} -m coverage run --append" + cmd=""${PY_EXE}" -m coverage run --append" fi # # download test data if needed @@ -763,6 +763,6 @@ then then install_deps fi - ${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/ - ${cmdPrefix}${PY_EXE} -m coverage report --ignore-errors + ${cmdPrefix}"${PY_EXE}" -m coverage combine --append .coverage/ + ${cmdPrefix}"${PY_EXE}" -m coverage report --ignore-errors fi From 9ced3e5f5dbf17e28abbd4b6950b6a2b26a1fa96 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 30 Nov 2023 12:24:11 +0800 Subject: [PATCH 20/67] add SoftclDiceLoss and SoftDiceclDiceLoss loss function in documentation (#7271) Fixes #7193 ### Description add SoftclDiceLoss and SoftDiceclDiceLoss loss function in documentation(https://github.com/Project-MONAI/MONAI/blob/dev/docs/source/losses.rst?plain=1) ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/losses.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 5d488afbb3..568c7dfc77 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -78,6 +78,16 @@ Segmentation Losses .. autoclass:: HausdorffDTLoss :members: +`SoftclDiceLoss` +~~~~~~~~~~~~~~~~ +.. autoclass:: SoftclDiceLoss + :members: + +`SoftDiceclDiceLoss` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SoftDiceclDiceLoss + :members: + Registration Losses ------------------- From 8642877f87ce6ae7f67feacca67ef4ce116f1609 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 30 Nov 2023 20:10:45 +0800 Subject: [PATCH 21/67] Skip Old Pytorch Versions for `SwinUNETR` (#7266) Fixes #7265. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham --- tests/test_swin_unetr.py | 46 ++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 9f6b1e7c0a..e34e5a3c8e 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,13 +24,21 @@ from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr from monai.networks.utils import copy_model_state from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config +from tests.utils import ( + assert_allclose, + pytorch_after, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + testing_data_config, +) einops, has_einops = optional_import("einops") TEST_CASE_SWIN_UNETR = [] case_idx = 0 test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] +checkpoint_vals = [True, False] if pytorch_after(1, 11) else [False] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: @@ -38,23 +46,25 @@ for img_size in ((64, 32, 192), (96, 32)): for feature_size in [12]: for norm_name in ["instance"]: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": img_size, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + for use_checkpoint in checkpoint_vals: + test_case = [ + { + "spatial_dims": len(img_size), + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": img_size, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], + "use_checkpoint": use_checkpoint, + }, + (2, in_channels, *img_size), + (2, out_channels, *img_size), + ] + case_idx += 1 + TEST_CASE_SWIN_UNETR.append(test_case) TEST_CASE_FILTER = [ [ From 708167a44af43bf833f6419c7de8111e9a21dcca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 1 Dec 2023 16:30:30 +0800 Subject: [PATCH 22/67] Bump conda-incubator/setup-miniconda from 2 to 3 (#7274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [conda-incubator/setup-miniconda](https://github.com/conda-incubator/setup-miniconda) from 2 to 3.
Release notes

Sourced from conda-incubator/setup-miniconda's releases.

Version 3.0.0

Features

  • #308 Update to node20
  • #291 Add conda-solver option (defaults to libmamba)

Fixes

  • #299 Fix condaBasePath when useBundled is false, and there's no pre-existing conda

Documentation

  • #309 Switch to main branch based development
  • #313 Specify team conda-incubator/setup-miniconda as codeowners
  • #318 README: update actions in examples, add security section, similar actions

Tasks and Maintenance

  • #307 Run dependabot against main branch and also update node packages
  • #311 Bump actions/checkout from 2 to 4
  • #310 Bump actions/cache from 1 to 3
  • #314 Strip/update dependencies
  • #315 Split lint into check and build, switch from npm install to npm ci
  • #317 Bump normalize-url from 4.5.1 to 8.0.0
  • #316 Faster workflow response / saving resources via timeout/concurrency policy

#308: conda-incubator/setup-miniconda#308 #291: conda-incubator/setup-miniconda#291 #299: conda-incubator/setup-miniconda#299 #309: conda-incubator/setup-miniconda#309 #313: conda-incubator/setup-miniconda#313 #318: conda-incubator/setup-miniconda#318 #307: conda-incubator/setup-miniconda#307 #311: conda-incubator/setup-miniconda#311 #310: conda-incubator/setup-miniconda#310 #314: conda-incubator/setup-miniconda#314 #315: conda-incubator/setup-miniconda#315 #317: conda-incubator/setup-miniconda#317 #316: conda-incubator/setup-miniconda#316

New Contributors

Full Changelog: https://github.com/conda-incubator/setup-miniconda/compare/v2...v3.0.0

Version 2.3.0

Documentation

  • #263 Update links to GitHub shell docs

... (truncated)

Changelog

Sourced from conda-incubator/setup-miniconda's changelog.

v3.0.1 (2023-11-29)

Fixes

  • #325 Fix environment activation on windows (a v3 regression) due to hard-coded install PATH

#325: conda-incubator/setup-miniconda#325

v3.0.0 (2023-11-27)

Features

  • #308 Update to node20
  • #291 Add conda-solver option (defaults to libmamba)

Fixes

  • #299 Fix condaBasePath when useBundled is false, and there's no pre-existing conda

Documentation

  • #309 Switch to main branch based development
  • #313 Specify team conda-incubator/setup-miniconda as codeowners
  • #318 README: update actions in examples, add security section, similar actions

Tasks and Maintenance

  • #307 Run dependabot against main branch and also update node packages
  • #311 Bump actions/checkout from 2 to 4
  • #310 Bump actions/cache from 1 to 3
  • #314 Strip/update dependencies
  • #315 Split lint into check and build, switch from npm install to npm ci
  • #317 Bump normalize-url from 4.5.1 to 8.0.0
  • #316 Faster workflow response / saving resources via timeout/concurrency policy

#308: conda-incubator/setup-miniconda#308 #291: conda-incubator/setup-miniconda#291 #299: conda-incubator/setup-miniconda#299 #309: conda-incubator/setup-miniconda#309 #313: conda-incubator/setup-miniconda#313 #318: conda-incubator/setup-miniconda#318 #307: conda-incubator/setup-miniconda#307 #311: conda-incubator/setup-miniconda#311 #310: conda-incubator/setup-miniconda#310

... (truncated)

Commits
  • 11b5629 Prepare 3.0.1 (#326)
  • 8706aa7 Fix env activation on win (a v3 regression) due to hard-coded install PATH (#...
  • c585a97 Bump conda-incubator/setup-miniconda from 2.3.0 to 3.0.0 (#321)
  • 2defc80 Prepare release (#320)
  • 0d5a56b Bump actions/checkout from 2 to 4 (#319)
  • 45fd3f9 Merge pull request #316 from dbast/timeout
  • d1e04fc Merge pull request #299 from isuruf/condaBasePath
  • fab0073 Merge pull request #318 from dbast/readme
  • fa6bdf9 Update with npm run build
  • d42f8b8 Fix condaBasePath when useBundled is false, and there's no pre-existing conda
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=conda-incubator/setup-miniconda&package-manager=github_actions&previous-version=2&new-version=3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index dc07f833be..a387c77ebd 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -32,7 +32,7 @@ jobs: maximum-size: 16GB disk-root: "D:" - uses: actions/checkout@v4 - - uses: conda-incubator/setup-miniconda@v2 + - uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true python-version: ${{ matrix.python-version }} From 039ee21316469ddb8362ff3edd779d39dbe13a0e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 5 Dec 2023 07:55:16 -0500 Subject: [PATCH 23/67] 6676 port generative networks autoencoderkl (#7260) Partially fixes #6676 ### Description Implements the AutoencoderKL network from MONAI Generative. NB this network is subject to a planned refactor once the porting is complete, [see here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 807 +++++++++++++++++++++++++++ tests/test_autoencoderkl.py | 276 +++++++++ 4 files changed, 1089 insertions(+) create mode 100644 monai/networks/nets/autoencoderkl.py create mode 100644 tests/test_autoencoderkl.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..dbfdf35784 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -595,6 +595,11 @@ Nets .. autoclass:: AutoEncoder :members: +`AutoEncoderKL` +~~~~~~~~~~~~~~~ +.. autoclass:: AutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..ea08246d25 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -14,6 +14,7 @@ from .ahnet import AHnet, Ahnet, AHNet from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder +from .autoencoderkl import AutoencoderKL from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py new file mode 100644 index 0000000000..9a9f35d5ae --- /dev/null +++ b/monai/networks/nets/autoencoderkl.py @@ -0,0 +1,807 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Sequence +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution + +# To install xformers, use pip install xformers==0.0.16rc401 +from monai.utils import ensure_tuple_rep, optional_import + +xformers, has_xformers = optional_import("xformers") + +__all__ = ["AutoencoderKL"] + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + conv: torch.Tensor = self.conv(x) + return conv + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + x = self.conv(x) + return x + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class _ResBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = F.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Attention block. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +class Encoder(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks: List[nn.Module] = [] + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + reversed_block_out_channels = list(reversed(channels)) + + blocks: List[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + _ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + _Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=use_convtranspose) + ) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class AutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResBlock) per level. + channels: number of output channels for each block. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_checkpoint: if True, use activation checkpoint to save memory. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpoint: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = Decoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + self.use_checkpoint = use_checkpoint + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + if self.use_checkpoint: + h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False) + else: + h = self.encoder(x) + + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu) + return reconstruction + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor + if self.use_checkpoint: + dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False) + else: + dec = self.decoder(z) + return dec + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + image = self.decode(z) + return image diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py new file mode 100644 index 0000000000..448f1e8e9a --- /dev/null +++ b/tests/test_autoencoderkl.py @@ -0,0 +1,276 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import AutoencoderKL +from tests.utils import SkipIfBeforePyTorchVersion + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + @parameterized.expand(CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_convtranspose_and_checkpointing( + self, input_param, input_shape, expected_shape, expected_latent_shape + ): + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + self.assertEqual(result[2].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_reconstruction_with_convtranspose_and_checkpointing(self): + input_param, input_shape, expected_shape, _ = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.reconstruct(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_encode_with_convtranspose_and_checkpointing(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_sampling_convtranspose_and_checkpointing(self): + input_param, _, _, expected_latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_decode_convtranspose_and_checkpointing(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + input_param = input_param.copy() + input_param.update({"use_checkpoint": True, "use_convtranspose": True}) + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + +if __name__ == "__main__": + unittest.main() From 20a57e7b53ebebf2d802f612272ad02ebc15fbdd Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 6 Dec 2023 22:36:50 -0500 Subject: [PATCH 24/67] 6676 port generative networks vqvae (#7285) Partially fixes https://github.com/Project-MONAI/MONAI/issues/6676 ### Description Implements the VQ-VAE network, including the vector quantizer block, from MONAI Generative. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: KumoLiu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 13 + monai/bundle/scripts.py | 2 +- monai/networks/layers/__init__.py | 1 + monai/networks/layers/vector_quantizer.py | 233 +++++++++++ monai/networks/nets/__init__.py | 1 + monai/networks/nets/autoencoderkl.py | 14 +- monai/networks/nets/vqvae.py | 466 ++++++++++++++++++++++ tests/test_vector_quantizer.py | 89 +++++ tests/test_vqvae.py | 274 +++++++++++++ 9 files changed, 1085 insertions(+), 8 deletions(-) create mode 100644 monai/networks/layers/vector_quantizer.py create mode 100644 monai/networks/nets/vqvae.py create mode 100644 tests/test_vector_quantizer.py create mode 100644 tests/test_vqvae.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index dbfdf35784..d8be26264b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,7 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift + Layers ------ @@ -408,6 +409,13 @@ Layers .. autoclass:: LLTM :members: +`Vector Quantizer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.layers.vector_quantizer.EMAQuantizer + :members: +.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils @@ -728,6 +736,11 @@ Nets .. autoclass:: voxelmorph +`VQ-VAE` +~~~~~~~~ +.. autoclass:: VQVAE + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 20a491e493..2565a3cf64 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -221,7 +221,7 @@ def _download_from_ngc( def _get_latest_bundle_version_monaihosting(name): url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name}" + full_url = f"{url}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..bd3e3af3af 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -37,4 +37,5 @@ ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .vector_quantizer import EMAQuantizer, VectorQuantizer from .weight_init import _no_grad_trunc_normal_, trunc_normal_ diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py new file mode 100644 index 0000000000..9c354e1009 --- /dev/null +++ b/monai/networks/layers/vector_quantizer.py @@ -0,0 +1,233 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from torch import nn + +__all__ = ["VectorQuantizer", "EMAQuantizer"] + + +class EMAQuantizer(nn.Module): + """ + Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural + Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation + that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit + 58d9a2746493717a7c9252938da7efa6006f3739. + + This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due + to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 + on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. + + Args: + spatial_dims: number of spatial dimensions of the input. + num_embeddings: number of atomic elements in the codebook. + embedding_dim: number of channels of the input and atomic elements. + commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. + decay: EMA decay. Defaults to 0.99. + epsilon: epsilon value. Defaults to 1e-5. + embedding_init: initialization method for the codebook. Defaults to "normal". + ddp_sync: whether to synchronize the codebook across processes. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + num_embeddings: int, + embedding_dim: int, + commitment_cost: float = 0.25, + decay: float = 0.99, + epsilon: float = 1e-5, + embedding_init: str = "normal", + ddp_sync: bool = True, + ): + super().__init__() + self.spatial_dims: int = spatial_dims + self.embedding_dim: int = embedding_dim + self.num_embeddings: int = num_embeddings + + assert self.spatial_dims in [2, 3], ValueError( + f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." + ) + + self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) + if embedding_init == "normal": + # Initialization is passed since the default one is normal inside the nn.Embedding + pass + elif embedding_init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") + self.embedding.weight.requires_grad = False + + self.commitment_cost: float = commitment_cost + + self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + # declare types for mypy + self.ema_cluster_size: torch.Tensor + self.ema_w: torch.Tensor + self.decay: float = decay + self.epsilon: float = epsilon + + self.ddp_sync: bool = ddp_sync + + # Precalculating required permutation shapes + self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1] + self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( + range(1, self.spatial_dims + 1) + ) + + def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. + + Args: + inputs: Encoding space tensors of shape [B, C, H, W, D]. + + Returns: + torch.Tensor: Flatten version of the input of shape [B*H*W*D, C]. + torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings]. + torch.Tensor: Quantization indices of shape [B,H,W,D,1] + + """ + with torch.cuda.amp.autocast(enabled=False): + encoding_indices_view = list(inputs.shape) + del encoding_indices_view[1] + + inputs = inputs.float() + + # Converting to channel last format + flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) + + # Calculate Euclidean distances + distances = ( + (flat_input**2).sum(dim=1, keepdim=True) + + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) + - 2 * torch.mm(flat_input, self.embedding.weight.t()) + ) + + # Mapping distances to indexes + encoding_indices = torch.max(-distances, dim=1)[1] + encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() + + # Quantize and reshape + encoding_indices = encoding_indices.view(encoding_indices_view) + + return flat_input, encodings, encoding_indices + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + """ + Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space + [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the + decoder. + + Args: + embedding_indices: Tensor in channel last format which holds indices referencing atomic + elements from self.embedding + + Returns: + torch.Tensor: Quantize space representation of encoding_indices in channel first format. + """ + with torch.cuda.amp.autocast(enabled=False): + embedding: torch.Tensor = ( + self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() + ) + return embedding + + def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: + """ + TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the + example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused + + Args: + encodings_sum: The summation of one hot representation of what encoding was used for each + position. + dw: The multiplication of the one hot representation of what encoding was used for each + position with the flattened input. + + Returns: + None + """ + if self.ddp_sync and torch.distributed.is_initialized(): + torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) + else: + pass + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat_input, encodings, encoding_indices = self.quantize(inputs) + quantized = self.embed(encoding_indices) + + # Use EMA to update the embedding vectors + if self.training: + with torch.no_grad(): + encodings_sum = encodings.sum(0) + dw = torch.mm(encodings.t(), flat_input) + + if self.ddp_sync: + self.distributed_synchronization(encodings_sum, dw) + + self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) + + # Laplace smoothing of the cluster size + n = self.ema_cluster_size.sum() + weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n + self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) + self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) + + # Encoding Loss + loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) + + # Straight Through Estimator + quantized = inputs + (quantized - inputs).detach() + + return quantized, loss, encoding_indices + + +class VectorQuantizer(torch.nn.Module): + """ + Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of + the quantization in their own class. + + Args: + quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index + based quantized representation. + """ + + def __init__(self, quantizer: EMAQuantizer): + super().__init__() + + self.quantizer: EMAQuantizer = quantizer + + self.perplexity: torch.Tensor = torch.rand(1) + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + quantized, loss, encoding_indices = self.quantizer(inputs) + # Perplexity calculations + avg_probs = ( + torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) + .float() + .div(encoding_indices.numel()) + ) + + self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return loss, quantized + + def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.quantizer.embed(embedding_indices=embedding_indices) + + def quantize(self, encodings: torch.Tensor) -> torch.Tensor: + output = self.quantizer(encodings) + encoding_indices: torch.Tensor = output[2] + return encoding_indices diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ea08246d25..db3c77c717 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -113,3 +113,4 @@ from .vitautoenc import ViTAutoEnc from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet +from .vqvae import VQVAE diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 9a9f35d5ae..f7ae77f056 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -38,7 +38,7 @@ class _Upsample(nn.Module): Convolution-based upsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels to the layer. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. """ @@ -98,7 +98,7 @@ class _Downsample(nn.Module): Convolution-based downsampling layer. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. """ @@ -132,7 +132,7 @@ class _ResBlock(nn.Module): residual connection between input and output. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: input channels to the layer. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. @@ -206,7 +206,7 @@ class _AttentionBlock(nn.Module): Attention block. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of @@ -325,7 +325,7 @@ class Encoder(nn.Module): Convolutional cascade that downsamples the image into a spatial latent space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. @@ -463,7 +463,7 @@ class Decoder(nn.Module): Convolutional cascade upsampling from a spatial latent space into an image space. Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. channels: sequence of block output channels. in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. @@ -611,7 +611,7 @@ class AutoencoderKL(nn.Module): and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. in_channels: number of input channels. out_channels: number of output channels. num_res_blocks: number of residual blocks (see _ResBlock) per level. diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py new file mode 100644 index 0000000000..d4771e203a --- /dev/null +++ b/monai/networks/nets/vqvae.py @@ -0,0 +1,466 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Tuple + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer +from monai.utils import ensure_tuple_rep + +__all__ = ["VQVAE"] + + +class VQVAEResidualUnit(nn.Module): + """ + Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving + Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf). + + The original implementation that can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150. + + Args: + spatial_dims: number of spatial spatial_dims of the input data. + in_channels: number of input channels. + num_res_channels: number of channels in the residual layers. + act: activation type and arguments. Defaults to RELU. + dropout: dropout ratio. Defaults to no dropout. + bias: whether to have a bias term. Defaults to True. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_channels: int, + act: tuple | str | None = Act.RELU, + dropout: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_res_channels = num_res_channels + self.act = act + self.dropout = dropout + self.bias = bias + + self.conv1 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=self.num_res_channels, + adn_ordering="DA", + act=self.act, + dropout=self.dropout, + bias=self.bias, + ) + + self.conv2 = Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.num_res_channels, + out_channels=self.in_channels, + bias=self.bias, + conv_only=True, + ) + + def forward(self, x): + return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True) + + +class Encoder(nn.Module): + """ + Encoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of channels in the latent space (embedding_dim). + channels: sequence containing the number of channels at each level of the encoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + dropout: dropout ratio. + act: activation type and arguments. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + downsample_parameters: Sequence[Tuple[int, int, int, int]], + dropout: float, + act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.downsample_parameters = downsample_parameters + self.dropout = dropout + self.act = act + + blocks: list[nn.Module] = [] + + for i in range(len(self.channels)): + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels if i == 0 else self.channels[i - 1], + out_channels=self.channels[i], + strides=self.downsample_parameters[i][0], + kernel_size=self.downsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=None if i == 0 else self.dropout, + dropout_dim=1, + dilation=self.downsample_parameters[i][2], + padding=self.downsample_parameters[i][3], + ) + ) + + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=self.channels[i], + num_res_channels=self.num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.channels[len(self.channels) - 1], + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class Decoder(nn.Module): + """ + Decoder module for VQ-VAE. + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of channels in the latent space (embedding_dim). + out_channels: number of output channels. + channels: sequence containing the number of channels at each level of the decoder. + num_res_layers: number of sequential residual layers at each level. + num_res_channels: number of channels in the residual layers at each level. + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + dropout: dropout ratio. + act: activation type and arguments. + output_act: activation type and arguments for the output. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + num_res_layers: int, + num_res_channels: Sequence[int], + upsample_parameters: Sequence[Tuple[int, int, int, int, int]], + dropout: float, + act: tuple | str | None, + output_act: tuple | str | None, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + self.upsample_parameters = upsample_parameters + self.dropout = dropout + self.act = act + self.output_act = output_act + + reversed_num_channels = list(reversed(self.channels)) + + blocks: list[nn.Module] = [] + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=self.in_channels, + out_channels=reversed_num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + reversed_num_res_channels = list(reversed(self.num_res_channels)) + for i in range(len(self.channels)): + for _ in range(self.num_res_layers): + blocks.append( + VQVAEResidualUnit( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + num_res_channels=reversed_num_res_channels[i], + act=self.act, + dropout=self.dropout, + ) + ) + + blocks.append( + Convolution( + spatial_dims=self.spatial_dims, + in_channels=reversed_num_channels[i], + out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1], + strides=self.upsample_parameters[i][0], + kernel_size=self.upsample_parameters[i][1], + adn_ordering="DA", + act=self.act, + dropout=self.dropout if i != len(self.channels) - 1 else None, + norm=None, + dilation=self.upsample_parameters[i][2], + conv_only=i == len(self.channels) - 1, + is_transposed=True, + padding=self.upsample_parameters[i][3], + output_padding=self.upsample_parameters[i][4], + ) + ) + + if self.output_act: + blocks.append(Act[self.output_act]()) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + return x + + +class VQVAE(nn.Module): + """ + Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative + Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf) + + The original implementation can be found at + https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/ + + Args: + spatial_dims: number of spatial spatial_dims. + in_channels: number of input channels. + out_channels: number of output channels. + downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int) and padding (int). + upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the + following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int). + num_res_layers: number of sequential residual layers at each level. + channels: number of channels at each level. + num_res_channels: number of channels in the residual layers at each level. + num_embeddings: VectorQuantization number of atomic elements in the codebook. + embedding_dim: VectorQuantization number of channels of the input and atomic elements. + commitment_cost: VectorQuantization commitment_cost. + decay: VectorQuantization decay. + epsilon: VectorQuantization epsilon. + act: activation type and arguments. + dropout: dropout ratio. + output_act: activation type and arguments for the output. + ddp_sync: whether to synchronize the codebook across processes. + use_checkpointing if True, use activation checkpointing to save memory. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int] = (96, 96, 192), + num_res_layers: int = 3, + num_res_channels: Sequence[int] | int = (96, 96, 192), + downsample_parameters: Sequence[Tuple[int, int, int, int]] + | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] + | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + num_embeddings: int = 32, + embedding_dim: int = 64, + embedding_init: str = "normal", + commitment_cost: float = 0.25, + decay: float = 0.5, + epsilon: float = 1e-5, + dropout: float = 0.0, + act: tuple | str | None = Act.RELU, + output_act: tuple | str | None = None, + ddp_sync: bool = True, + use_checkpointing: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_dims = spatial_dims + self.channels = channels + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.use_checkpointing = use_checkpointing + + if isinstance(num_res_channels, int): + num_res_channels = ensure_tuple_rep(num_res_channels, len(channels)) + + if len(num_res_channels) != len(channels): + raise ValueError( + "`num_res_channels` should be a single integer or a tuple of integers with the same length as " + "`num_channls`." + ) + if all(isinstance(values, int) for values in upsample_parameters): + upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels) + else: + upsample_parameters_tuple = upsample_parameters + + if all(isinstance(values, int) for values in downsample_parameters): + downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels) + else: + downsample_parameters_tuple = downsample_parameters + + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple): + raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + # check if downsample_parameters is a tuple of ints or a tuple of tuples of ints + if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple): + raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.") + + for parameter in downsample_parameters_tuple: + if len(parameter) != 4: + raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.") + + for parameter in upsample_parameters_tuple: + if len(parameter) != 5: + raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.") + + if len(downsample_parameters_tuple) != len(channels): + raise ValueError( + "`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + if len(upsample_parameters_tuple) != len(channels): + raise ValueError( + "`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`." + ) + + self.num_res_layers = num_res_layers + self.num_res_channels = num_res_channels + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=embedding_dim, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + downsample_parameters=downsample_parameters_tuple, + dropout=dropout, + act=act, + ) + + self.decoder = Decoder( + spatial_dims=spatial_dims, + in_channels=embedding_dim, + out_channels=out_channels, + channels=channels, + num_res_layers=num_res_layers, + num_res_channels=num_res_channels, + upsample_parameters=upsample_parameters_tuple, + dropout=dropout, + act=act, + output_act=output_act, + ) + + self.quantizer = VectorQuantizer( + quantizer=EMAQuantizer( + spatial_dims=spatial_dims, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_cost=commitment_cost, + decay=decay, + epsilon=epsilon, + embedding_init=embedding_init, + ddp_sync=ddp_sync, + ) + ) + + def encode(self, images: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False) + else: + output = self.encoder(images) + return output + + def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x_loss, x = self.quantizer(encodings) + return x, x_loss + + def decode(self, quantizations: torch.Tensor) -> torch.Tensor: + output: torch.Tensor + + if self.use_checkpointing: + output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False) + else: + output = self.decoder(quantizations) + return output + + def index_quantize(self, images: torch.Tensor) -> torch.Tensor: + return self.quantizer.quantize(self.encode(images=images)) + + def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor: + return self.decode(self.quantizer.embed(embedding_indices)) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + quantizations, quantization_losses = self.quantize(self.encode(images)) + reconstruction = self.decode(quantizations) + + return reconstruction, quantization_losses + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z = self.encode(x) + e, _ = self.quantize(z) + return e + + def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: + e, _ = self.quantize(z) + image = self.decode(e) + return image diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py new file mode 100644 index 0000000000..43533d0377 --- /dev/null +++ b/tests/test_vector_quantizer.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from math import prod + +import torch +from parameterized import parameterized + +from monai.networks.layers import EMAQuantizer, VectorQuantizer + +TEST_CASES = [ + [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)], + [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)], +] + + +class TestEMA(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_ema_shape(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + layer = layer.train() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + layer = layer.eval() + outputs = layer(x) + self.assertEqual(outputs[0].shape, input_shape) + self.assertEqual(outputs[2].shape, output_shape) + + @parameterized.expand(TEST_CASES) + def test_ema_quantize(self, input_param, input_shape, output_shape): + layer = EMAQuantizer(**input_param) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C) + self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E) + self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D]) + + def test_ema(self): + layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) + original_weight_0 = layer.embedding.weight[0].clone() + original_weight_1 = layer.embedding.weight[1].clone() + x_0 = original_weight_0 + x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 + + x_1 = original_weight_1 + x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + x_1 = x_1.repeat(1, 1, 1, 2) + + x = torch.cat([x_0, x_1], dim=0) + layer = layer.train() + _ = layer(x) + + self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) + self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) + + +class TestVectorQuantizer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_shape(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer(x) + self.assertEqual(outputs[1].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape): + layer = VectorQuantizer(EMAQuantizer(**input_param)) + x = torch.randn(input_shape) + outputs = layer.quantize(x) + self.assertEqual(outputs.shape, output_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py new file mode 100644 index 0000000000..4916dc2faa --- /dev/null +++ b/tests/test_vqvae.py @@ -0,0 +1,274 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vqvae import VQVAE +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": 4, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": (2, 4, 1, 1), + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "num_res_layers": 1, + "num_res_channels": (4, 4), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": (2, 4, 1, 1, 0), + "num_embeddings": 8, + "embedding_dim": 8, + }, + (1, 1, 8, 8, 8), + (1, 1, 8, 8, 8), + ], +] + +TEST_LATENT_SHAPE = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "num_embeddings": 16, + "embedding_dim": 8, +} + + +class TestVQVAE(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11)) + def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape): + device = "cuda" if torch.cuda.is_available() else "cpu" + input_param = input_param.copy() + input_param.update({"use_checkpointing": True}) + + net = VQVAE(**input_param).to(device) + + with eval_mode(net): + result, _ = net(torch.randn(input_shape).to(device)) + + self.assertEqual(result.shape, expected_shape) + + # Removed this test case since TorchScript currently does not support activation checkpoint. + # def test_script(self): + # net = VQVAE( + # spatial_dims=2, + # in_channels=1, + # out_channels=1, + # downsample_parameters=((2, 4, 1, 1),) * 2, + # upsample_parameters=((2, 4, 1, 1, 0),) * 2, + # num_res_layers=1, + # channels=(8, 8), + # num_res_channels=(8, 8), + # num_embeddings=16, + # embedding_dim=8, + # ddp_sync=False, + # ) + # test_data = torch.randn(1, 1, 16, 16) + # test_script_save(net, test_data) + + def test_channels_not_same_size_of_num_res_channels(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_downsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_channels_not_same_size_of_upsample_parameters(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 3, + ) + + def test_downsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=(("test", 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameters_not_sequence_or_int(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=(("test", 4, 1, 1, 0),) * 2, + ) + + def test_downsample_parameter_length_different_4(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16), + downsample_parameters=((2, 4, 1),) * 3, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + ) + + def test_upsample_parameter_length_different_5(self): + with self.assertRaises(ValueError): + VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(16, 16), + num_res_channels=(16, 16, 16), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3, + ) + + def test_encode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8, 8)) + + def test_index_quantize_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) + + self.assertEqual(latent.shape, (1, 8, 8)) + + def test_decode_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + def test_decode_samples_shape(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + + net = VQVAE(**TEST_LATENT_SHAPE).to(device) + + with eval_mode(net): + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) + + self.assertEqual(latent.shape, (1, 1, 32, 32)) + + +if __name__ == "__main__": + unittest.main() From 5e2466d675d56ccb7ce503e7a64bb85cc9b8ab54 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 11 Dec 2023 11:10:20 -0500 Subject: [PATCH 25/67] 6676 port generative networks transformer (#7300) Towards #6676 . ### Description Adds a simple decoder-only transformer architecture. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/transformer.py | 314 +++++++++++++++++++++++++++++ tests/test_transformer.py | 73 +++++++ 4 files changed, 393 insertions(+) create mode 100644 monai/networks/nets/transformer.py create mode 100644 tests/test_transformer.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d8be26264b..06f60fe8af 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -613,6 +613,11 @@ Nets .. autoclass:: VarAutoEncoder :members: +`DecoderOnlyTransformer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DecoderOnlyTransformer + :members: + `ViT` ~~~~~ .. autoclass:: ViT diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index db3c77c717..08384b4d52 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,7 @@ from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex +from .transformer import DecoderOnlyTransformer from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py new file mode 100644 index 0000000000..b742c12205 --- /dev/null +++ b/monai/networks/nets/transformer.py @@ -0,0 +1,314 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks.mlp import MLPBlock +from monai.utils import optional_import + +xops, has_xformers = optional_import("xformers.ops") +__all__ = ["DecoderOnlyTransformer"] + + +class _SABlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: dropout ratio. Defaults to no dropout. + qkv_bias: bias term for the qkv linear layer. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + self.causal = causal + self.sequence_length = sequence_length + self.with_cross_attention = with_cross_attention + self.use_flash_attention = use_flash_attention + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + self.dropout_rate = dropout_rate + + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + # key, query, value projections + self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + # regularization + self.drop_weights = nn.Dropout(dropout_rate) + self.drop_output = nn.Dropout(dropout_rate) + + # output projection + self.out_proj = nn.Linear(hidden_size, hidden_size) + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query = self.to_q(x) + + kv = context if context is not None else x + _, kv_t, _ = kv.size() + key = self.to_k(kv) + value = self.to_v(kv) + + query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) + key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) + y: torch.Tensor + if self.use_flash_attention: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + y = xops.memory_efficient_attention( + query=query, + key=key, + value=value, + scale=self.scale, + p=self.dropout_rate, + attn_bias=xops.LowerTriangularMask() if self.causal else None, + ) + + else: + query = query.transpose(1, 2) # (b, nh, t, hs) + key = key.transpose(1, 2) # (b, nh, kv_t, hs) + value = value.transpose(1, 2) # (b, nh, kv_t, hs) + + # manual implementation of attention + query = query * self.scale + attention_scores = query @ key.transpose(-2, -1) + + if self.causal: + attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.drop_weights(attention_probs) + y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) + + y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) + + y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side + + y = self.out_proj(y) + y = self.drop_output(y) + return y + + +class _TransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + qkv_bias: apply bias term for the qkv linear layer + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + with_cross_attention: Whether to use cross attention for conditioning. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + num_heads: int, + dropout_rate: float = 0.0, + qkv_bias: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + self.with_cross_attention = with_cross_attention + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=causal, + sequence_length=sequence_length, + use_flash_attention=use_flash_attention, + ) + + if self.with_cross_attention: + self.norm2 = nn.LayerNorm(hidden_size) + self.cross_attn = _SABlock( + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + with_cross_attention=with_cross_attention, + causal=False, + use_flash_attention=use_flash_attention, + ) + self.norm3 = nn.LayerNorm(hidden_size) + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm2(x), context=context) + x = x + self.mlp(self.norm3(x)) + return x + + +class AbsolutePositionalEmbedding(nn.Module): + """Absolute positional embedding. + + Args: + max_seq_len: Maximum sequence length. + embedding_dim: Dimensionality of the embedding. + """ + + def __init__(self, max_seq_len: int, embedding_dim: int) -> None: + super().__init__() + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.embedding = nn.Embedding(max_seq_len, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len = x.size() + positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) + embedding: torch.Tensor = self.embedding(positions) + return embedding + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. + embedding_dropout_rate: Dropout rate for the embedding. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + embedding_dropout_rate: float = 0.0, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + self.with_cross_attention = with_cross_attention + + self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) + self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) + self.embedding_dropout = nn.Dropout(embedding_dropout_rate) + + self.blocks = nn.ModuleList( + [ + _TransformerBlock( + hidden_size=attn_layers_dim, + mlp_dim=attn_layers_dim * 4, + num_heads=attn_layers_heads, + dropout_rate=0.0, + qkv_bias=False, + causal=True, + sequence_length=max_seq_len, + with_cross_attention=with_cross_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(attn_layers_depth) + ] + ) + + self.to_logits = nn.Linear(attn_layers_dim, num_tokens) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + tok_emb = self.token_embeddings(x) + pos_emb = self.position_embeddings(x) + x = self.embedding_dropout(tok_emb + pos_emb) + + for block in self.blocks: + x = block(x, context=context) + logits: torch.Tensor = self.to_logits(x) + return logits diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 0000000000..ea6ebdf50f --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,73 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DecoderOnlyTransformer + +TEST_CASES = [] +for dropout_rate in np.linspace(0, 1, 2): + for attention_layer_dim in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + TEST_CASES.append( + [ + { + "num_tokens": 10, + "max_seq_len": 16, + "attn_layers_dim": attention_layer_dim, + "attn_layers_depth": 2, + "attn_layers_heads": num_heads, + "embedding_dropout_rate": dropout_rate, + } + ] + ) + + +class TestDecoderOnlyTransformer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_unconditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + @parameterized.expand(TEST_CASES) + def test_conditioned_models(self, input_param): + net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"])) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 + ) + + def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): + DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + embedding_dropout_rate=-1, + ) + + +if __name__ == "__main__": + unittest.main() From 249c39aa1b40159e3c18f2449000f0262d83ba80 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 12 Dec 2023 02:09:37 -0500 Subject: [PATCH 26/67] 6676 port generative networks ddpm (#7304) Towards #6676 . ### Description Adds a DDPM unet. Refactoring for some of the blocks here is scheduled [here](https://github.com/Project-MONAI/MONAI/issues/7227). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/diffusion_model_unet.py | 2138 +++++++++++++++++++ tests/test_diffusion_model_unet.py | 535 +++++ 4 files changed, 2679 insertions(+) create mode 100644 monai/networks/nets/diffusion_model_unet.py create mode 100644 tests/test_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 06f60fe8af..417fb8ac73 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -583,6 +583,11 @@ Nets .. autoclass:: VNet :members: +`DiffusionModelUnet` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionModelUNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 08384b4d52..31fbd73b4e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -35,6 +35,7 @@ densenet201, densenet264, ) +from .diffusion_model_unet import DiffusionModelUNet from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py new file mode 100644 index 0000000000..1532215c70 --- /dev/null +++ b/monai/networks/nets/diffusion_model_unet.py @@ -0,0 +1,2138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 + +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["DiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class _CrossAttention(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. + """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + output: torch.Tensor = self.to_out(x) + return output + + +class _BasicTransformerBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + A basic Transformer block. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = _CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = _CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class _SpatialTransformer(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + _BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class _AttentionBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class _Downsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + output: torch.Tensor = self.op(x) + return output + + +class _Upsample(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class _ResnetBlock(nn.Module): + """ + NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make + use of this block as support is not guaranteed. For more information see: + https://github.com/Project-MONAI/MONAI/issues/7227 + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsampler: nn.Module | None + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.downsampler: nn.Module | None + if add_downsample: + if resblock_updown: + self.downsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = _Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + + self.resnet_1 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class DiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] + + is_final_block = i == len(channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += [down_block_res_sample] + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output + + +class DiffusionModelEncoder(nn.Module): + """ + Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on + Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306). + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see _ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups") + if len(channels) != len(attention_levels): + raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = channels[0] + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) # - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + + self.down_blocks.append(down_block) + + self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + for downsample_block in self.down_blocks: + h, _ = downsample_block(hidden_states=h, temb=emb, context=context) + + h = h.reshape(h.shape[0], -1) + output: torch.Tensor = self.out(h) + + return output diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py new file mode 100644 index 0000000000..d40a31a1da --- /dev/null +++ b/tests/test_diffusion_model_unet.py @@ -0,0 +1,535 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import DiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + } + ], +] + +DROPOUT_OK = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 0.25, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + } + ], +] + +DROPOUT_WRONG = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "dropout_cattn": 3.0, + } + ] +] + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = DiffusionModelUNet( + spatial_dims=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + # Test dropout specification for cross-attention blocks + @parameterized.expand(DROPOUT_WRONG) + def test_wrong_dropout(self, input_param): + with self.assertRaises(ValueError): + _ = DiffusionModelUNet(**input_param) + + @parameterized.expand(DROPOUT_OK) + def test_right_dropout(self, input_param): + _ = DiffusionModelUNet(**input_param) + + +if __name__ == "__main__": + unittest.main() From c9f028eba9d338d3be6d0fa3b5c46a5004981369 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 13 Dec 2023 22:37:51 -0500 Subject: [PATCH 27/67] 6676 port generative networks controlnet (#7312) Part of #6676 . ### Description Ports the ControlNet. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/controlnet.py | 421 ++++++++++++++++++++++++++++++ tests/test_controlnet.py | 177 +++++++++++++ 4 files changed, 604 insertions(+) create mode 100644 monai/networks/nets/controlnet.py create mode 100644 tests/test_controlnet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 417fb8ac73..0960fcdbc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -588,6 +588,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`ControlNet` +~~~~~~~~~~~~ +.. autoclass:: ControlNet + :members: + `RegUNet` ~~~~~~~~~ .. autoclass:: RegUNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 31fbd73b4e..58cb652bae 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -18,6 +18,7 @@ from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus from .classifier import Classifier, Critic, Discriminator +from .controlnet import ControlNet from .daf3d import DAF3D from .densenet import ( DenseNet, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py new file mode 100644 index 0000000000..d98755f401 --- /dev/null +++ b/monai/networks/nets/controlnet.py @@ -0,0 +1,421 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Network to encode the conditioning into a latent space. + """ + + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]): + super().__init__() + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(channels) - 1): + channel_in = channels[i] + channel_out = channels[i + 1] + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_in, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=channel_in, + out_channels=channel_out, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.conv_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNet(nn.Module): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError( + f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={channels} and norm_num_groups={norm_num_groups}" + ) + + if len(channels) != len(attention_levels): + raise ValueError( + f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"channels={channels} and attention_levels={attention_levels}" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + f"num_head_channels should have the same length as attention_levels, but got channels={channels} and " + f"attention_levels={attention_levels} . For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + channels=conditioning_embedding_num_channels, + out_channels=channels[0], + ) + + # down + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_final_block = i == len(channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + # + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = channels[-1] + + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, H, W, [D]). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]) + conditioning_scale: conditioning scale. + context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init. + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = [] + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py new file mode 100644 index 0000000000..07dfa2e49b --- /dev/null +++ b/tests/test_controlnet.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.controlnet import ControlNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 4, + }, + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + }, + (1, 8, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (4, 4, 4), + "num_head_channels": 4, + "attention_levels": (False, False, False), + "norm_num_groups": 4, + "resblock_updown": True, + }, + (1, 4, 4, 4, 4), + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + (1, 8, 4, 4), + ], +] + + +class TestControlNet(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + @parameterized.expand(COND_CASES_2D) + def test_shape_conditioned_models(self, input_param, expected_output_shape): + input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] + input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) + net = ControlNet(**input_param) + with eval_mode(net): + x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"]) + result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3))) + self.assertEqual(len(result[0]), 2 * len(input_param["channels"])) + self.assertEqual(result[1].shape, expected_output_shape) + + +if __name__ == "__main__": + unittest.main() From 58b4575de50f18d34044bfe8e3bf3e391cfa2c66 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 14 Dec 2023 06:06:59 -0500 Subject: [PATCH 28/67] Adds patchgan discriminator (#7319) Part of #6676 . ### Description Adds a patchgan-style discriminator, both single scale and multiscale. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 8 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/patchgan_discriminator.py | 247 ++++++++++++++++++ tests/test_patch_gan_dicriminator.py | 179 +++++++++++++ 4 files changed, 435 insertions(+) create mode 100644 monai/networks/nets/patchgan_discriminator.py create mode 100644 tests/test_patch_gan_dicriminator.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 0960fcdbc0..8e79298941 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -756,6 +756,14 @@ Nets .. autoclass:: VQVAE :members: +`PatchGANDiscriminator` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchDiscriminator + :members: + +.. autoclass:: MultiScalePatchDiscriminator + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 58cb652bae..0f0d033d63 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -55,6 +55,7 @@ from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet from .milmodel import MILModel from .netadapter import NetAdapter +from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator from .quicknat import Quicknat from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py new file mode 100644 index 0000000000..3b089616ce --- /dev/null +++ b/monai/networks/nets/patchgan_discriminator.py @@ -0,0 +1,247 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks import Convolution +from monai.networks.layers import Act + + +class MultiScalePatchDiscriminator(nn.Sequential): + """ + Multi-scale Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images + at different spatial scales. + + Args: + num_d: number of discriminators + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first + discriminator. Each subsequent discriminator has one additional layer, meaning the output size is halved. + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels in each discriminator + kernel_size: kernel size of the convolution layers + activation: activation layer type + norm: normalisation type + bias: introduction of layer bias + dropout: probability of dropout applied, defaults to 0. + minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture + requested isn't going to downsample the input image beyond value of 1. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + num_d: int, + num_layers_d: int, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + dropout: float | tuple = 0.0, + minimum_size_im: int = 256, + last_conv_kernel_size: int = 1, + ) -> None: + super().__init__() + self.num_d = num_d + self.num_layers_d = num_layers_d + self.num_channels = channels + self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) + for i_ in range(self.num_d): + num_layers_d_i = self.num_layers_d * (i_ + 1) + output_size = float(minimum_size_im) / (2**num_layers_d_i) + if output_size < 1: + raise AssertionError( + f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}." + "Please reduce num_layers, reduce num_D or enter bigger images." + ) + subnet_d = PatchDiscriminator( + spatial_dims=spatial_dims, + channels=self.num_channels, + in_channels=in_channels, + out_channels=out_channels, + num_layers_d=num_layers_d_i, + kernel_size=kernel_size, + activation=activation, + norm=norm, + bias=bias, + padding=self.padding, + dropout=dropout, + last_conv_kernel_size=last_conv_kernel_size, + ) + + self.add_module("discriminator_%d" % i_, subnet_d) + + def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: + """ + Args: + i: Input tensor + + Returns: + list of outputs and another list of lists with the intermediate features + of each discriminator. + """ + + out: list[torch.Tensor] = [] + intermediate_features: list[list[torch.Tensor]] = [] + for disc in self.children(): + out_d: list[torch.Tensor] = disc(i) + out.append(out_d[-1]) + intermediate_features.append(out_d[:-1]) + + return out, intermediate_features + + +class PatchDiscriminator(nn.Sequential): + """ + Patch-GAN discriminator based on Pix2PixHD: + High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585) + + + Args: + spatial_dims: number of spatial dimensions (1D, 2D etc.) + channels: number of filters in the first convolutional layer (doubled for each subsequent layer) + in_channels: number of input channels + out_channels: number of output channels + num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator. + kernel_size: kernel size of the convolution layers + act: activation type and arguments. Defaults to LeakyReLU. + norm: feature normalization type and arguments. Defaults to batch norm. + bias: whether to have a bias term in convolution blocks. Defaults to False. + padding: padding to be applied to the convolutional layers + dropout: proportion of dropout applied, defaults to 0. + last_conv_kernel_size: kernel size of the last convolutional layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + in_channels: int, + out_channels: int = 1, + num_layers_d: int = 3, + kernel_size: int = 4, + activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + norm: str | tuple = "BATCH", + bias: bool = False, + padding: int | Sequence[int] = 1, + dropout: float | tuple = 0.0, + last_conv_kernel_size: int | None = None, + ) -> None: + super().__init__() + self.num_layers_d = num_layers_d + self.num_channels = channels + if last_conv_kernel_size is None: + last_conv_kernel_size = kernel_size + + self.add_module( + "initial_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=in_channels, + out_channels=channels, + act=activation, + bias=True, + norm=None, + dropout=dropout, + padding=padding, + strides=2, + ), + ) + + input_channels = channels + output_channels = channels * 2 + + # Initial Layer + for l_ in range(self.num_layers_d): + if l_ == self.num_layers_d - 1: + stride = 1 + else: + stride = 2 + layer = Convolution( + spatial_dims=spatial_dims, + kernel_size=kernel_size, + in_channels=input_channels, + out_channels=output_channels, + act=activation, + bias=bias, + norm=norm, + dropout=dropout, + padding=padding, + strides=stride, + ) + self.add_module("%d" % l_, layer) + input_channels = output_channels + output_channels = output_channels * 2 + + # Final layer + self.add_module( + "final_conv", + Convolution( + spatial_dims=spatial_dims, + kernel_size=last_conv_kernel_size, + in_channels=input_channels, + out_channels=out_channels, + bias=True, + conv_only=True, + padding=int((last_conv_kernel_size - 1) / 2), + dropout=0.0, + strides=1, + ), + ) + + self.apply(self.initialise_weights) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + x: input tensor + + Returns: + list of intermediate features, with the last element being the output. + """ + out = [x] + for submodel in self.children(): + intermediate_output = submodel(out[-1]) + out.append(intermediate_output) + + return out[1:] + + def initialise_weights(self, m: nn.Module) -> None: + """ + Initialise weights of Convolution and BatchNorm layers. + + Args: + m: instance of torch.nn.module (or of class inheriting torch.nn.module) + """ + classname = m.__class__.__name__ + if classname.find("Conv2d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv3d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("Conv1d") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py new file mode 100644 index 0000000000..c19898e70d --- /dev/null +++ b/tests/test_patch_gan_dicriminator.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator +from tests.utils import test_script_save + +TEST_PATCHGAN = [ + [ + { + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512]), + (1, 8, 128, 256), + (1, 1, 32, 64), + ], + [ + { + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + }, + torch.rand([1, 3, 256, 512, 256]), + (1, 8, 128, 256, 128), + (1, 1, 32, 64, 32), + ], +] + +TEST_MULTISCALE_PATCHGAN = [ + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512]), + [(1, 1, 32, 64), (1, 1, 4, 8)], + [4, 7], + ], + [ + { + "num_d": 2, + "num_layers_d": 3, + "spatial_dims": 3, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + }, + torch.rand([1, 3, 256, 512, 256]), + [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], + [4, 7], + ], +] +TEST_TOO_SMALL_SIZE = [ + { + "num_d": 2, + "num_layers_d": 6, + "spatial_dims": 2, + "channels": 8, + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "activation": "LEAKYRELU", + "norm": "instance", + "bias": False, + "dropout": 0.1, + "minimum_size_im": 256, + } +] + + +class TestPatchGAN(unittest.TestCase): + @parameterized.expand(TEST_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output): + net = PatchDiscriminator(**input_param) + with eval_mode(net): + result = net.forward(input_data) + self.assertEqual(tuple(result[0].shape), expected_shape_feature) + self.assertEqual(tuple(result[-1].shape), expected_shape_output) + + def test_script(self): + net = PatchDiscriminator( + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +class TestMultiscalePatchGAN(unittest.TestCase): + @parameterized.expand(TEST_MULTISCALE_PATCHGAN) + def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): + net = MultiScalePatchDiscriminator(**input_param) + with eval_mode(net): + result, features = net.forward(input_data) + for r_ind, r in enumerate(result): + self.assertEqual(tuple(r.shape), expected_shape[r_ind]) + for o_d_ind, o_d in enumerate(features): + self.assertEqual(len(o_d), features_lengths[o_d_ind]) + + def test_too_small_shape(self): + with self.assertRaises(AssertionError): + MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) + + def test_script(self): + net = MultiScalePatchDiscriminator( + num_d=2, + num_layers_d=3, + spatial_dims=2, + channels=8, + in_channels=3, + out_channels=1, + kernel_size=3, + activation="LEAKYRELU", + norm="instance", + bias=False, + dropout=0.1, + minimum_size_im=256, + ) + i = torch.rand([1, 3, 256, 512]) + test_script_save(net, i) + + +if __name__ == "__main__": + unittest.main() From aef4daf0115f385397163c64a38bc2b3d11546d8 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 19 Dec 2023 03:21:32 +0000 Subject: [PATCH 29/67] 6676 port generative networks spade (#7320) Towards #6676 . ### Description This adds SPADE-enabled autoencoder and diffusion_model_unet architectures. They are new implementations for each network, rather than options in the existing network, because @virginiafdez and I felt that adding additional options to the existing networks to enable spade compatibility significantly reduced the readability of them for users who were not interested in SPADE functionality. These are the last networks to be ported over. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 14 + monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/spade_norm.py | 96 ++ monai/networks/nets/__init__.py | 2 + monai/networks/nets/spade_autoencoderkl.py | 473 +++++++++ .../nets/spade_diffusion_model_unet.py | 908 ++++++++++++++++++ test_spade_autoencoderkl.py | 260 +++++ test_spade_diffusion_model_unet.py | 558 +++++++++++ 8 files changed, 2312 insertions(+) create mode 100644 monai/networks/blocks/spade_norm.py create mode 100644 monai/networks/nets/spade_autoencoderkl.py create mode 100644 monai/networks/nets/spade_diffusion_model_unet.py create mode 100644 test_spade_autoencoderkl.py create mode 100644 test_spade_diffusion_model_unet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8e79298941..79d5ef822e 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -258,6 +258,10 @@ N-Dim Fourier Transform .. autofunction:: monai.networks.blocks.fft_utils_t.fftshift .. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift +`SPADE` +~~~~~~~ +.. autoclass:: monai.networks.blocks.spade_norm.SPADE + :members: Layers ------ @@ -588,6 +592,11 @@ Nets .. autoclass:: DiffusionModelUNet :members: +`SPADEDiffusionModelUNet` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEDiffusionModelUNet + :members: + `ControlNet` ~~~~~~~~~~~~ .. autoclass:: ControlNet @@ -618,6 +627,11 @@ Nets .. autoclass:: AutoencoderKL :members: +`SPADEAutoencoderKL` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SPADEAutoencoderKL + :members: + `VarAutoEncoder` ~~~~~~~~~~~~~~~~ .. autoclass:: VarAutoEncoder diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index e67cb3376f..afb6664bd9 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -30,6 +30,7 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .spade_norm import SPADE from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py new file mode 100644 index 0000000000..b1046f3154 --- /dev/null +++ b/monai/networks/blocks/spade_norm.py @@ -0,0 +1,96 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import ADN, Convolution + + +class SPADE(nn.Module): + """ + Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a + semantic map. This block is used in SPADE-based image-to-image translation models, as described in + Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291). + + Args: + label_nc: number of semantic labels + norm_nc: number of output channels + kernel_size: kernel size + spatial_dims: number of spatial dimensions + hidden_channels: number of channels in the intermediate gamma and beta layers + norm: type of base normalisation used before applying the SPADE normalisation + norm_params: parameters for the base normalisation + """ + + def __init__( + self, + label_nc: int, + norm_nc: int, + kernel_size: int = 3, + spatial_dims: int = 2, + hidden_channels: int = 64, + norm: str | tuple = "INSTANCE", + norm_params: dict | None = None, + ) -> None: + super().__init__() + + if norm_params is None: + norm_params = {} + if len(norm_params) != 0: + norm = (norm, norm_params) + self.param_free_norm = ADN( + act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc + ) + self.mlp_shared = Convolution( + spatial_dims=spatial_dims, + in_channels=label_nc, + out_channels=hidden_channels, + kernel_size=kernel_size, + norm=None, + act="LEAKYRELU", + ) + self.mlp_gamma = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + self.mlp_beta = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_channels, + out_channels=norm_nc, + kernel_size=kernel_size, + act=None, + ) + + def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: + """ + Args: + x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels. + segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels. + The map will be interpolated to the dimension of x internally. + """ + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out: torch.Tensor = normalized * (1 + gamma) + beta + return out diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 0f0d033d63..a7ce16ad64 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -106,6 +106,8 @@ seresnext50, seresnext101, ) +from .spade_autoencoderkl import SPADEAutoencoderKL +from .spade_diffusion_model_unet import SPADEDiffusionModelUNet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py new file mode 100644 index 0000000000..e064c19740 --- /dev/null +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -0,0 +1,473 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.autoencoderkl import Encoder, _AttentionBlock, _Upsample +from monai.utils import ensure_tuple_rep + +__all__ = ["SPADEAutoencoderKL"] + + +class SPADEResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm_num_groups: int, + norm_eps: float, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "affine": False}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.nin_shortcut: nn.Module + if self.in_channels != self.out_channels: + self.nin_shortcut = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h, seg) + h = F.silu(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class SPADEDecoder(nn.Module): + """ + Convolutional cascade upsampling from a spatial latent space into an image space. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from channels contain an attention block. + label_nc: number of semantic channels for SPADE normalisation. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + label_nc: int, + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.label_nc = label_nc + + reversed_block_out_channels = list(reversed(channels)) + + blocks: list[nn.Module] = [] + + # Initial convolution + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + SPADEResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(_Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch, use_convtranspose=False)) + + blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + if isinstance(block, SPADEResBlock): + x = block(x, seg) + else: + x = block(x) + return x + + +class SPADEAutoencoderKL(nn.Module): + """ + Autoencoder model with KL-regularized latent space based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + label_nc: number of semantic channels for SPADE normalisation. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. + norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number. + norm_eps: epsilon for the normalization. + with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. + with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + label_nc: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): + raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups") + + if len(channels) != len(attention_levels): + raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels") + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) + + if len(num_res_blocks) != len(channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`channels`." + ) + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.encoder = Encoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + channels=channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + channels=channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + label_nc=label_nc, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + self.quant_conv_mu = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.quant_conv_log_sigma = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.post_quant_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=latent_channels, + out_channels=latent_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.latent_channels = latent_channels + + def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations. + + Args: + x: BxCx[SPATIAL DIMS] tensor + + """ + h = self.encoder(x) + z_mu = self.quant_conv_mu(h) + z_log_var = self.quant_conv_log_sigma(h) + z_log_var = torch.clamp(z_log_var, -30.0, 20.0) + z_sigma = torch.exp(z_log_var / 2) + + return z_mu, z_sigma + + def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor: + """ + From the mean and sigma representations resulting of encoding an image through the latent space, + obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and + adding the mean. + + Args: + z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image + z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image + + Returns: + sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE] + """ + eps = torch.randn_like(z_sigma) + z_vae = z_mu + eps * z_sigma + return z_vae + + def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Encodes and decodes an input image. + + Args: + x: BxCx[SPATIAL DIMENSIONS] tensor. + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + reconstructed image, of the same shape as input + """ + z_mu, _ = self.encode(x) + reconstruction = self.decode(z_mu, seg) + return reconstruction + + def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + """ + Based on a latent space sample, forwards it through the Decoder. + + Args: + z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE] + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + Returns: + decoded image tensor + """ + z = self.post_quant_conv(z) + dec: torch.Tensor = self.decoder(z, seg) + return dec + + def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + reconstruction = self.decode(z, seg) + return reconstruction, z_mu, z_sigma + + def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + z_mu, z_sigma = self.encode(x) + z = self.sampling(z_mu, z_sigma) + return z + + def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + image = self.decode(z, seg) + return image diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py new file mode 100644 index 0000000000..d53327100e --- /dev/null +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -0,0 +1,908 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import nn + +from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE +from monai.networks.nets.diffusion_model_unet import ( + _AttentionBlock, + _Downsample, + _ResnetBlock, + _SpatialTransformer, + _Upsample, + get_down_block, + get_mid_block, + get_timestep_embedding, + zero_module, +) +from monai.utils import ensure_tuple_rep, optional_import + +# To install xformers, use pip install xformers==0.0.16rc401 +xops, has_xformers = optional_import("xformers.ops") + + +__all__ = ["SPADEDiffusionModelUNet"] + + +class SPADEResnetBlock(nn.Module): + """ + Residual block with timestep conditioning and SPADE norm. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + label_nc: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = SPADE( + label_nc=label_nc, + norm_nc=in_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = SPADE( + label_nc=label_nc, + norm_nc=self.out_channels, + norm="GROUP", + norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True}, + hidden_channels=spade_intermediate_channels, + kernel_size=3, + spatial_dims=spatial_dims, + ) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + self.skip_connection: nn.Module + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h, seg) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h, seg) + h = self.nonlinearity(h) + h = self.conv2(h) + output: torch.Tensor = self.skip_connection(x) + h + return output + + +class SPADEUpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADEAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class SPADECrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + label_nc: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SPADEResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + ) + attentions.append( + _SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.upsampler: nn.Module | None + if add_upsample: + if resblock_updown: + self.upsampler = _ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = _Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + seg: torch.Tensor | None = None, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb, seg) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_spade_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + label_nc: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, +) -> nn.Module: + if with_attn: + return SPADEAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + elif with_cross_attn: + return SPADECrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + spade_intermediate_channels=spade_intermediate_channels, + ) + else: + return SPADEUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + label_nc=label_nc, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + spade_intermediate_channels=spade_intermediate_channels, + ) + + +class SPADEDiffusionModelUNet(nn.Module): + """ + UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for + semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at + https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + label_nc: number of semantic channels for SPADE normalisation. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + spade_intermediate_channels: number of intermediate channels for SPADE block layer + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + spade_intermediate_channels: int = 128, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." + ) + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + self.label_nc = label_nc + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_spade_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + seg: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples: list[torch.Tensor] = [h] + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples.append(down_block_res_sample) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context) + + # 7. output block + output: torch.Tensor = self.out(h) + + return output diff --git a/test_spade_autoencoderkl.py b/test_spade_autoencoderkl.py new file mode 100644 index 0000000000..6675a6db67 --- /dev/null +++ b/test_spade_autoencoderkl.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEAutoencoderKL + +CASES = [ + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 3, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + "spade_intermediate_channels": 32, + }, + (1, 1, 16, 16), + (1, 3, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], +] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class TestSPADEAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape): + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device)) + self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False), + latent_channels=8, + num_res_blocks=1, + norm_num_groups=16, + ) + + def test_model_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + + def test_shape_encode(self): + input_param, input_shape, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, _, expected_latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0] + net = SPADEAutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + + def test_wrong_shape_decode(self): + net = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, False), + num_res_blocks=1, + norm_num_groups=4, + ) + with self.assertRaises(RuntimeError): + _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test_spade_diffusion_model_unet.py b/test_spade_diffusion_model_unet.py new file mode 100644 index 0000000000..c8a2103cf6 --- /dev/null +++ b/test_spade_diffusion_model_unet.py @@ -0,0 +1,558 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import SPADEDiffusionModelUNet + +UNCOND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, True, True), + "num_head_channels": (0, 2, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +UNCOND_CASES_3D = [ + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + "spade_intermediate_channels": 256, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": (0, 0, 4), + "norm_num_groups": 8, + "label_nc": 3, + } + ], +] + +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + "label_nc": 3, + } + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + "label_nc": 3, + } + ], +] + + +class TestSPADEDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_2D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + def test_timestep_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) + ) + + def test_label_with_wrong_shape(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with self.assertRaises(RuntimeError): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16)) + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16)) + + def test_model_channels_not_multiple_of_norm_num_group(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 12), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_attention_levels_with_different_length_num_head_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + num_head_channels=(0, 2), + norm_num_groups=8, + ) + + def test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + norm_num_groups=8, + num_head_channels=8, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_with_conditioning_cross_attention_dim_none(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=None, + norm_num_groups=8, + ) + + def test_context_with_conditioning_none(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + + with self.assertRaises(ValueError): + with eval_mode(net): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + context=torch.rand((1, 1, 3)), + ) + + def test_shape_conditioned_models_class_conditioning(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + net = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + with self.assertRaises(ValueError): + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 32)), + ) + + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16)), + torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16)) + + +class TestDiffusionModelUNet3D(unittest.TestCase): + @parameterized.expand(UNCOND_CASES_3D) + def test_shape_unconditioned_models(self, input_param): + net = SPADEDiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward( + torch.rand((1, 1, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, input_param["label_nc"], 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + def test_shape_with_different_in_channel_out_channel(self): + in_channels = 6 + out_channels = 3 + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=in_channels, + out_channels=out_channels, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=4, + ) + with eval_mode(net): + result = net.forward( + torch.rand((1, in_channels, 16, 16, 16)), + torch.randint(0, 1000, (1,)).long(), + torch.rand((1, 3, 16, 16, 16)), + ) + self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + + def test_shape_conditioned_models(self): + net = SPADEDiffusionModelUNet( + spatial_dims=3, + label_nc=3, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(16, 16, 16), + attention_levels=(False, False, True), + norm_num_groups=16, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 16, 16)), + timesteps=torch.randint(0, 1000, (1,)).long(), + seg=torch.rand((1, 3, 16, 16, 16)), + context=torch.rand((1, 1, 3)), + ) + self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() From 2fe972f21b3665d0baf5490244d6d538a37465b1 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 21 Dec 2023 10:38:30 +0000 Subject: [PATCH 30/67] Adds adversarial trainer Signed-off-by: Mark Graham --- monai/engines/__init__.py | 2 +- monai/engines/trainer.py | 283 +++++++++++++++++- .../test_integration_workflows_adversarial.py | 171 +++++++++++ 3 files changed, 453 insertions(+), 3 deletions(-) create mode 100644 tests/test_integration_workflows_adversarial.py diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d8dc51f620..be65c39f70 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,7 +12,7 @@ from __future__ import annotations from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator -from .trainer import GanTrainer, SupervisedTrainer, Trainer +from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( IterationEvents, PrepareBatch, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 61b7028e11..dec5450725 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -22,7 +22,7 @@ from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import GanKeys, min_version, optional_import +from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys @@ -34,7 +34,7 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] class Trainer(Workflow): @@ -423,3 +423,282 @@ def _iteration( GanKeys.GLOSS: g_loss.item(), GanKeys.DLOSS: d_total_loss.item(), } + + +class AdversarialTrainer(Trainer): + """ + Standard supervised training workflow for adversarial loss enabled neural networks. + + Args: + device: an object representing the device on which to run. + max_epochs: the total epoch number for engine to run. + train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata. + g_network: ''generator'' (G) network architecture. + g_optimizer: G optimizer function. + g_loss_function: G loss function for adversarial training. + recon_loss_function: G loss function for reconstructions. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer function. + d_loss_function: D loss function for adversarial training.. + epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to + the host. For other cases, this argument has no effect. + prepare_batch: function to parse image and label for current iteration. + iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input + parameters. if not provided, use `self._iteration()` instead. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based + transforms composed by `Compose`. Defaults to None + key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics + when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args + (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and + `best_metric_epoch` with current metric and epoch, default to `greater than`. + train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, recommend + `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device | str, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + g_loss_function: Callable, + recon_loss_function: Callable, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + d_loss_function: Callable, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable | None = None, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + ): + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.register_events(*AdversarialIterationEvents) + + self.state.g_network = g_network + self.state.g_optimizer = g_optimizer + self.state.g_loss_function = g_loss_function + self.state.recon_loss_function = recon_loss_function + + self.state.d_network = d_network + self.state.d_optimizer = d_optimizer + self.state.d_loss_function = d_loss_function + + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + + self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None + + self.optim_set_to_none = optim_set_to_none + self._complete_state_dict_user_keys() + + def _complete_state_dict_user_keys(self) -> None: + """ + This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for + checkpoint saving. + + Follows the example found at: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict + """ + self._state_dict_user_keys.extend( + ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"] + ) + + g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None) + if callable(g_loss_state_dict): + self._state_dict_user_keys.append("g_loss_function") + + d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None) + if callable(d_loss_state_dict): + self._state_dict_user_keys.append("d_loss_function") + + recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None) + if callable(recon_loss_state_dict): + self._state_dict_user_keys.append("recon_loss_function") + + def _iteration( + self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | int | float | bool]: + """ + Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised + Learning this is equal to IMAGE. + - PRED: prediction result of model. + - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up). + - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE. + - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED. + - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images. + - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images. + - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function. + - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the + discriminator loss for the fake images. That is backpropagated through the generator only. + - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images. That is backpropagated through the + discriminator only. + + Args: + engine: `AdversarialTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + + Raises: + ValueError: must provide batch data for current iteration. + + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + + if len(batch) == 2: + inputs, targets = batch + args: tuple = () + kwargs: dict = {} + else: + inputs, targets, args, kwargs = batch + + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs} + + def _compute_generator_loss() -> None: + engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer( + inputs, engine.state.g_network, *args, **kwargs + ) + engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES] + engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs + ) + engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function( + engine.state.output[AdversarialKeys.FAKES], targets + ).mean() + engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED) + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function( + engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED) + + # Train Generator + engine.state.g_network.train() + engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.g_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_generator_loss() + + engine.state.output[Keys.LOSS] = ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ) + engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_scaler.step(engine.state.g_optimizer) + engine.state.g_scaler.update() + else: + _compute_generator_loss() + ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_optimizer.step() + engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED) + + def _compute_discriminator_loss() -> None: + engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.REALS].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function( + engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED) + + # Train Discriminator + engine.state.d_network.train() + engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.d_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_discriminator_loss() + + engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED) + engine.state.d_scaler.step(engine.state.d_optimizer) + engine.state.d_scaler.update() + else: + _compute_discriminator_loss() + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward() + engine.state.d_optimizer.step() + + return engine.state.output diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py new file mode 100644 index 0000000000..e0ff93bbe3 --- /dev/null +++ b/tests/test_integration_workflows_adversarial.py @@ -0,0 +1,171 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest +from glob import glob + +import nibabel as nib +import numpy as np +import torch + +import monai +from monai.data import create_test_image_2d +from monai.engines import AdversarialTrainer +from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler +from monai.networks.nets import AutoEncoder, Discriminator +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd +from monai.utils import AdversarialKeys as Keys +from monai.utils import CommonKeys, set_determinism +from tests.utils import DistTestCase, TimedCall, skip_if_quick + + +def run_training_test(root_dir, device="cuda:0"): + learning_rate = 2e-4 + real_label = 1 + fake_label = 0 + + real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) + train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)] + + # prepare real data + train_transforms = Compose( + [ + LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]), + EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2), + ScaleIntensityd(keys=[CommonKeys.IMAGE]), + RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5), + ] + ) + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) + train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) + + # Create Discriminator + discriminator_net = Discriminator( + in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5 + ).to(device) + discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate) + discriminator_loss_criterion = torch.nn.BCELoss() + + def discriminator_loss(real_logits, fake_logits): + real_target = real_logits.new_full((real_logits.shape[0], 1), real_label) + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label) + real_loss = discriminator_loss_criterion(real_logits, real_target) + fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return torch.div(torch.add(real_loss, fake_loss), 2) + + # Create Generator + generator_network = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 16, 32, 64), + strides=(2, 2, 2, 2), + num_res_units=1, + num_inter_units=1, + ) + generator_network = generator_network.to(device) + generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate) + generator_loss_criterion = torch.nn.MSELoss() + + def reconstruction_loss(recon_images, real_images): + return generator_loss_criterion(recon_images, real_images) + + def generator_loss(fake_logits): + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label) + recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return recon_loss + + key_train_metric = None + + train_handlers = [ + StatsHandler( + name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + TensorBoardStatsHandler( + log_dir=root_dir, + tag_name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + CheckpointSaver( + save_dir=root_dir, + save_dict={"g_net": generator_network, "d_net": discriminator_net}, + save_interval=2, + epoch_level=True, + ), + ] + + num_epochs = 5 + + trainer = AdversarialTrainer( + device=device, + max_epochs=num_epochs, + train_data_loader=train_loader, + g_network=generator_network, + g_optimizer=generator_optimiser, + g_loss_function=generator_loss, + recon_loss_function=reconstruction_loss, + d_network=discriminator_net, + d_optimizer=discriminator_opt, + d_loss_function=discriminator_loss, + non_blocking=True, + key_train_metric=key_train_metric, + train_handlers=train_handlers, + ) + trainer.run() + + return trainer.state + + +@skip_if_quick +class IntegrationWorkflowsAdversarialTrainer(DistTestCase): + def setUp(self): + set_determinism(seed=0) + + self.data_dir = tempfile.mkdtemp() + for i in range(40): + im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + monai.config.print_config() + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + @TimedCall(seconds=200, daemon=False) + def test_training(self): + torch.manual_seed(0) + + finish_state = run_training_test(self.data_dir, device=self.device) + + # Assert AdversarialTrainer training finished + self.assertEqual(finish_state.iteration, 100) + self.assertEqual(finish_state.epoch, 5) + + +if __name__ == "__main__": + unittest.main() From 33dce3a0d0909e26f482672956ac2609327c6dc1 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 21 Dec 2023 10:49:11 +0000 Subject: [PATCH 31/67] Updates docs Signed-off-by: Mark Graham --- docs/source/engines.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/engines.rst b/docs/source/engines.rst index afb2682822..a015c7b2a3 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -30,6 +30,11 @@ Workflows .. autoclass:: GanTrainer :members: +`AdversarialTrainer` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: AdversarialTrainer + :members: + `Evaluator` ~~~~~~~~~~~ .. autoclass:: Evaluator From 023b6bbdfa1f8bb6c42920e158e3019eae33d70e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 21 Dec 2023 10:49:27 +0000 Subject: [PATCH 32/67] Adds preparebatch Signed-off-by: Mark Graham --- monai/engines/utils.py | 79 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 02c718cd14..a76645ce36 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -13,9 +13,10 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union, cast import torch +import torch.nn as nn from monai.config import IgniteInfo from monai.transforms import apply_transform @@ -36,6 +37,8 @@ "PrepareBatch", "PrepareBatchDefault", "PrepareBatchExtraInput", + "DiffusionPrepareBatch", + "VPredictionPrepareBatch", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -238,6 +241,80 @@ def _get_data(key: str) -> torch.Tensor: return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_ +class DiffusionPrepareBatch(PrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". + This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images: torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["conditioning"] = batchdata[self.condition_name].to( + device, non_blocking=non_blocking, **kwargs + ) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) + + def default_make_latent( num_latents: int, latent_size: int, From a017b8bee02e792824373004601539a805d45551 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 21 Dec 2023 11:11:33 +0000 Subject: [PATCH 33/67] Outline test, need schedulers and inferer implemented to finish Signed-off-by: Mark Graham --- tests/test_prepare_batch_diffusion.py | 88 +++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/test_prepare_batch_diffusion.py diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py new file mode 100644 index 0000000000..c72053de99 --- /dev/null +++ b/tests/test_prepare_batch_diffusion.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized +from monai.engines.utils import DiffusionPrepareBatch +from monai.networks.nets import DiffusionModelUNet +from monai.engines import SupervisedEvaluator +from tests.utils import assert_allclose + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + + +class TestPrepareBatchDiffusion(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_content(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.randn(image_size).to(device), + } + ] + # set up engine + network = DiffusionModelUNet(**input_args).to(device) + num_train_timesteps = 10 + scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps) + inferer = DiffusionInferer() + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + for k, v in output["pred"].items(): + if isinstance(v, torch.Tensor): + assert_allclose(v, expected_value[k].to(device)) + else: + self.assertEqual(v, expected_value[k]) + + +if __name__ == "__main__": + unittest.main() From 64735eb4d1ca153d43909854008079f3feef2180 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 3 Jan 2024 16:30:24 +0000 Subject: [PATCH 34/67] 6676 port diffusion schedulers (#7332) Towards #6676 . ### Description This adds some base classes for DDPM noise schedulers + three scheduler types. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/networks.rst | 20 ++ monai/networks/schedulers/__init__.py | 17 ++ monai/networks/schedulers/ddim.py | 284 ++++++++++++++++++++++ monai/networks/schedulers/ddpm.py | 243 +++++++++++++++++++ monai/networks/schedulers/pndm.py | 316 +++++++++++++++++++++++++ monai/networks/schedulers/scheduler.py | 203 ++++++++++++++++ monai/utils/misc.py | 4 +- tests/test_scheduler_ddim.py | 83 +++++++ tests/test_scheduler_ddpm.py | 104 ++++++++ tests/test_scheduler_pndm.py | 108 +++++++++ 10 files changed, 1380 insertions(+), 2 deletions(-) create mode 100644 monai/networks/schedulers/__init__.py create mode 100644 monai/networks/schedulers/ddim.py create mode 100644 monai/networks/schedulers/ddpm.py create mode 100644 monai/networks/schedulers/pndm.py create mode 100644 monai/networks/schedulers/scheduler.py create mode 100644 tests/test_scheduler_ddim.py create mode 100644 tests/test_scheduler_ddpm.py create mode 100644 tests/test_scheduler_pndm.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 79d5ef822e..f9375f1e97 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -778,6 +778,26 @@ Nets .. autoclass:: MultiScalePatchDiscriminator :members: +Diffusion Schedulers +-------------------- +.. autoclass:: monai.networks.schedulers.Scheduler + :members: + +`DDPM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDPMScheduler + :members: + +`DDIM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.DDIMScheduler + :members: + +`PNDM Scheduler` +~~~~~~~~~~~~~~~~ +.. autoclass:: monai.networks.schedulers.PNDMScheduler + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py new file mode 100644 index 0000000000..29e9020d65 --- /dev/null +++ b/monai/networks/schedulers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .ddim import DDIMScheduler +from .ddpm import DDPMScheduler +from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py new file mode 100644 index 0000000000..ec47ff8dc6 --- /dev/null +++ b/monai/networks/schedulers/ddim.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDIMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = DDIMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") + + self.prediction_type = prediction_type + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) + + self.clip_sample = clip_sample + self.steps_offset = steps_offset + + # default the number of inference timesteps to the number of train steps + self.num_inference_steps: int + self.set_timesteps(self.num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + eta: weight of noise for added noise in diffusion step. + predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**0.5 + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon + + # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample + + def reversed_step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_post_sample -> "x_t+1" + + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas at timestep t+1 + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py new file mode 100644 index 0000000000..a5173a1b65 --- /dev/null +++ b/monai/networks/schedulers/ddpm.py @@ -0,0 +1,243 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. + predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if timestep > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py new file mode 100644 index 0000000000..c0728bbdff --- /dev/null +++ b/monai/networks/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., + "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + skip_prk_steps: + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms step. + set_alpha_to_one: + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type: member of DDPMPredictionType + steps_offset: + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = PNDMPredictionType.EPSILON, + steps_offset: int = 0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") + + self.prediction_type = prediction_type + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + """ + # return a tuple for consistency with samplers that return (previous pred, original sample pred) + + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None + + def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output = 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = torch.Tensor() + + # cur_sample should not be an empty torch.Tensor() + cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample + + prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + ) + + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = torch.Tensor() + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.prediction_type == PNDMPredictionType.V_PREDICTION: + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py new file mode 100644 index 0000000000..17bb526abc --- /dev/null +++ b/monai/networks/schedulers/scheduler.py @@ -0,0 +1,203 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/monai/utils/misc.py b/monai/utils/misc.py index d6ff370f69..4f2501a7ee 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_right(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: +def unsqueeze_left(arr: torch.Tensor, ndim: int) -> torch.Tensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py new file mode 100644 index 0000000000..1a8f8cab67 --- /dev/null +++ b/tests/test_scheduler_ddim.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDIMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(num_inference_steps=100) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDIMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDIMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py new file mode 100644 index 0000000000..f0447aded2 --- /dev/null +++ b/tests/test_scheduler_ddpm.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import DDPMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_2D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + ) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + for variance_type in ["fixed_small", "fixed_large"]: + TEST_3D_CASE.append( + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + ) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = DDPMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) + + def test_step_learned(self): + for variance_type in ["learned", "learned_range"]: + scheduler = DDPMScheduler(variance_type=variance_type) + model_output = torch.randn(2, 6, 16, 16) + sample = torch.randn(2, 3, 16, 16) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, sample.shape) + self.assertEqual(output_step[1].shape, sample.shape) + + def test_set_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = DDPMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py new file mode 100644 index 0000000000..69e5e403f5 --- /dev/null +++ b/tests/test_scheduler_pndm.py @@ -0,0 +1,108 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.schedulers import PNDMScheduler +from tests.utils import assert_allclose + +TEST_2D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) + +TEST_3D_CASE = [] +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) + +TEST_CASES = TEST_2D_CASE + TEST_3D_CASE + +TEST_FULl_LOOP = [ + [ + {"schedule": "linear_beta"}, + (1, 1, 2, 2), + torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]), + ] +] + + +class TestDDPMScheduler(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_add_noise(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + original_sample = torch.zeros(input_shape) + noise = torch.randn_like(original_sample) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() + noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) + self.assertEqual(noisy.shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_step_shape(self, input_param, input_shape, expected_shape): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(600) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) + self.assertEqual(output_step[0].shape, expected_shape) + self.assertEqual(output_step[1], None) + + @parameterized.expand(TEST_FULl_LOOP) + def test_full_timestep_loop(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3) + + @parameterized.expand(TEST_FULl_LOOP) + def test_timestep_two_loops(self, input_param, input_shape, expected_output): + scheduler = PNDMScheduler(**input_param) + scheduler.set_timesteps(50) + torch.manual_seed(42) + model_output = torch.randn(input_shape) + sample = torch.randn(input_shape) + for t in range(50): + sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample) + torch.manual_seed(42) + model_output2 = torch.randn(input_shape) + sample2 = torch.randn(input_shape) + scheduler.set_timesteps(50) + for t in range(50): + sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2) + assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3) + + def test_set_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 100) + self.assertEqual(len(scheduler.timesteps), 100) + + def test_set_timesteps_prk(self): + scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) + scheduler.set_timesteps(num_inference_steps=100) + self.assertEqual(scheduler.num_inference_steps, 109) + self.assertEqual(len(scheduler.timesteps), 109) + + def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): + scheduler = PNDMScheduler(num_train_timesteps=1000) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=2000) + + +if __name__ == "__main__": + unittest.main() From 2a8104d6e67daf9fcdbca36068826e5b8ebcb48d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 5 Jan 2024 06:52:15 +0000 Subject: [PATCH 35/67] 6676 port diffusion schedulers (#7364) This is an update to PR https://github.com/Project-MONAI/MONAI/pull/7332 - I addressed the comments but failed to push the changes before it was merged! Changes are very minor. ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/schedulers/ddim.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index ec47ff8dc6..78e3cc2a0c 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -34,23 +34,10 @@ import numpy as np import torch -from monai.utils import StrEnum - +from .ddpm import DDPMPredictionType from .scheduler import Scheduler - -class DDIMPredictionType(StrEnum): - """ - Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. - - epsilon: predicting the noise of the diffusion process - sample: directly predicting the noisy sample - v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf - """ - - EPSILON = "epsilon" - SAMPLE = "sample" - V_PREDICTION = "v_prediction" +DDIMPredictionType = DDPMPredictionType class DDIMScheduler(Scheduler): @@ -126,6 +113,13 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N self.num_inference_steps = num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps + if self.steps_offset >= step_ratio: + raise ValueError( + f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " + f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" + f" the max train timestep." + ) + # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) @@ -159,7 +153,6 @@ def step( timestep: current discrete timestep in the diffusion chain. sample: current instance of sample being created by diffusion process. eta: weight of noise for added noise in diffusion step. - predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. Returns: From cc17f72a64f8edfbc46194904d69efe52b79b98d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 8 Jan 2024 14:56:58 +0000 Subject: [PATCH 36/67] Adds ordering util (#7369) Towards #6676 . ### Description This ordering util got missed out my previous PR for the Generative utils. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- docs/source/utils.rst | 5 + monai/utils/ordering.py | 207 ++++++++++++++++++++++++++ tests/test_ordering.py | 318 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 530 insertions(+) create mode 100644 monai/utils/ordering.py create mode 100644 tests/test_ordering.py diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 527247799f..fef671e1f8 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -81,3 +81,8 @@ Component store --------------- .. autoclass:: monai.utils.component_store.ComponentStore :members: + +Ordering +-------- +.. automodule:: monai.utils.ordering + :members: diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py new file mode 100644 index 0000000000..1be61f98ab --- /dev/null +++ b/monai/utils/ordering.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np + +from monai.utils.enums import OrderingTransformations, OrderingType + + +class Ordering: + """ + Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with + one of the following transformations: + Reflection (see np.flip for more details). + Transposition (see np.transpose for more details). + 90-degree rotation (see np.rot90 for more details). + + The transformations are applied in the order specified by the transformation_order parameter. + + Args: + ordering_type: The ordering type. One of the following: + - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from + top to bottom. Also called a row major ordering. + - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like + pattern from top left towards right gowing in a spiral towards the center. + - random': The image is projected into a 1D sequence by randomly shuffling the image. + spatial_dims: The number of spatial dimensions of the image. + dimensions: The dimensions of the image. + reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. + transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. + rot90_axes: A tuple of tuples indicating the axes to rotate the image along. + transformation_order: The order in which to apply the transformations. + """ + + def __init__( + self, + ordering_type: str, + spatial_dims: int, + dimensions: tuple[int, int, int] | tuple[int, int, int, int], + reflected_spatial_dims: tuple[bool, bool] | None = None, + transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None, + rot90_axes: tuple[tuple[int, int], ...] | None = None, + transformation_order: tuple[str, ...] = ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + ) -> None: + super().__init__() + self.ordering_type = ordering_type + + if self.ordering_type not in list(OrderingType): + raise ValueError( + f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." + ) + + self.spatial_dims = spatial_dims + self.dimensions = dimensions + + if len(dimensions) != self.spatial_dims + 1: + raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") + + self.reflected_spatial_dims = reflected_spatial_dims + self.transpositions_axes = transpositions_axes + self.rot90_axes = rot90_axes + if len(set(transformation_order)) != len(transformation_order): + raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") + + for transformation in transformation_order: + if transformation not in list(OrderingTransformations): + raise ValueError( + f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." + ) + self.transformation_order = transformation_order + + self.template = self._create_template() + self._sequence_ordering = self._create_ordering() + self._revert_sequence_ordering = np.argsort(self._sequence_ordering) + + def __call__(self, x: np.ndarray) -> np.ndarray: + x = x[self._sequence_ordering] + + return x + + def get_sequence_ordering(self) -> np.ndarray: + return self._sequence_ordering + + def get_revert_sequence_ordering(self) -> np.ndarray: + return self._revert_sequence_ordering + + def _create_ordering(self) -> np.ndarray: + self.template = self._transform_template() + order = self._order_template(template=self.template) + + return order + + def _create_template(self) -> np.ndarray: + spatial_dimensions = self.dimensions[1:] + template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) + + return template + + def _transform_template(self) -> np.ndarray: + for transformation in self.transformation_order: + if transformation == OrderingTransformations.TRANSPOSE.value: + self.template = self._transpose_template(template=self.template) + elif transformation == OrderingTransformations.ROTATE_90.value: + self.template = self._rot90_template(template=self.template) + elif transformation == OrderingTransformations.REFLECT.value: + self.template = self._flip_template(template=self.template) + + return self.template + + def _transpose_template(self, template: np.ndarray) -> np.ndarray: + if self.transpositions_axes is not None: + for axes in self.transpositions_axes: + template = np.transpose(template, axes=axes) + + return template + + def _flip_template(self, template: np.ndarray) -> np.ndarray: + if self.reflected_spatial_dims is not None: + for axis, to_reflect in enumerate(self.reflected_spatial_dims): + template = np.flip(template, axis=axis) if to_reflect else template + + return template + + def _rot90_template(self, template: np.ndarray) -> np.ndarray: + if self.rot90_axes is not None: + for axes in self.rot90_axes: + template = np.rot90(template, axes=axes) + + return template + + def _order_template(self, template: np.ndarray) -> np.ndarray: + depths = None + if self.spatial_dims == 2: + rows, columns = template.shape[0], template.shape[1] + else: + rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) + + sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + + ordering = np.array([template[tuple(e)] for e in sequence]) + + return ordering + + @staticmethod + def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths is not None: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) + for c in col_idx: + if depths: + depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) + + for d in depth_idx: + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + + return idx_np + + @staticmethod + def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray: + idx: list[tuple] = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx_np = np.array(idx) + np.random.shuffle(idx_np) + + return idx_np diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 0000000000..0c52dba5e5 --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,318 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.utils.enums import OrderingTransformations, OrderingType +from monai.utils.ordering import Ordering + +TEST_2D_NON_RANDOM = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 0, 1], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 1, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 1, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 3, 1], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 0, 2], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 2, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], +] + +TEST_2D_RANDOM = [ + [ + { + "ordering_type": OrderingType.RANDOM, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [[0, 1, 2, 3], [0, 1, 3, 2]], + ] +] + +TEST_3D = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 3, + "dimensions": (1, 2, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3, 4, 5, 6, 7], + ] +] + +TEST_ORDERING_TYPE_FAILURE = [ + [ + { + "ordering_type": "hilbert", + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + +TEST_ORDERING_TRANSFORMATION_FAILURE = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + "flip", + ), + } + ] +] + +TEST_REVERT = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + + +class TestOrdering(unittest.TestCase): + @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) + def test_ordering(self, input_param, expected_sequence_ordering): + ordering = Ordering(**input_param) + self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) + + @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) + def test_ordering_type_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) + def test_ordering_transformation_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_2D_RANDOM) + def test_random(self, input_param, not_in_expected_sequence_ordering): + ordering = Ordering(**input_param) + + not_in = [ + np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) + for sequence in not_in_expected_sequence_ordering + ] + + self.assertFalse(np.any(not_in)) + + @parameterized.expand(TEST_REVERT) + def test_revert(self, input_param): + sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() + + ordering = Ordering(**input_param) + + reverted_sequence = sequence[ordering.get_sequence_ordering()] + reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] + + self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() From f377f2ef6fcd330d2d9a14eb23dcaaaac547b49e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 18 Jan 2024 12:38:46 +0000 Subject: [PATCH 37/67] 6676 port generative inferers (#7379) Part of #6676 . ### Description Adds Inferers to assist with training and sampling from diffusion models and controllers. Also takes the opportunity to make two changes which slipped through the previous PRs: - rename the `num_channels` arg in the spade diffusion unet to `channels` to be consistent with all the other models added from Generative - this slipped through in the networks PR. - add the `Ordering` class to `__init__.py` for easier import ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/inferers.rst | 23 + monai/inferers/__init__.py | 5 + monai/inferers/inferer.py | 1280 ++++++++++++++++- monai/networks/nets/diffusion_model_unet.py | 6 +- .../nets/spade_diffusion_model_unet.py | 40 +- monai/utils/__init__.py | 1 + setup.cfg | 14 +- tests/test_controlnet_inferers.py | 1270 ++++++++++++++++ tests/test_diffusion_inferer.py | 226 +++ tests/test_flexible_unet.py | 2 +- tests/test_invertd.py | 12 +- tests/test_latent_diffusion_inferer.py | 796 ++++++++++ tests/test_ordering.py | 29 - .../test_spade_autoencoderkl.py | 0 .../test_spade_diffusion_model_unet.py | 66 +- tests/test_vqvaetransformer_inferer.py | 284 ++++ 16 files changed, 3955 insertions(+), 99 deletions(-) create mode 100644 tests/test_controlnet_inferers.py create mode 100644 tests/test_diffusion_inferer.py create mode 100644 tests/test_latent_diffusion_inferer.py rename test_spade_autoencoderkl.py => tests/test_spade_autoencoderkl.py (100%) rename test_spade_diffusion_model_unet.py => tests/test_spade_diffusion_model_unet.py (92%) create mode 100644 tests/test_vqvaetransformer_inferer.py diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 33f9e14d83..326f56e96c 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -49,6 +49,29 @@ Inferers :members: :special-members: __call__ +`DiffusionInferer` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionInferer + :members: + :special-members: __call__ + +`LatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LatentDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetDiffusionInferer + :members: + :special-members: __call__ + +`ControlNetLatentDiffusionInferer` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ControlNetLatentDiffusionInferer + :members: + :special-members: __call__ Splitters --------- diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 960380bfb8..fc78b9f7c4 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -12,13 +12,18 @@ from __future__ import annotations from .inferer import ( + ControlNetDiffusionInferer, + ControlNetLatentDiffusionInferer, + DiffusionInferer, Inferer, + LatentDiffusionInferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer, SlidingWindowInfererAdapt, + VQVAETransformerInferer, ) from .merger import AvgMerger, Merger, ZarrAvgMerger from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 0b4199938d..72bcb8fd5a 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,24 +11,41 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from functools import partial from pydoc import locate from typing import Any import torch import torch.nn as nn +import torch.nn.functional as F from monai.apps.utils import get_logger +from monai.data import decollate_batch from monai.data.meta_tensor import MetaTensor from monai.data.thread_buffer import ThreadBuffer from monai.inferers.merger import AvgMerger, Merger from monai.inferers.splitter import Splitter from monai.inferers.utils import compute_importance_map, sliding_window_inference -from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DecoderOnlyTransformer, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import Scheduler +from monai.transforms import CenterSpatialCrop, SpatialPad +from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + logger = get_logger(__name__) __all__ = [ @@ -752,3 +769,1264 @@ def network_wrapper( return out return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out) + + +class DiffusionInferer(Inferer): + """ + DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass + for a training iteration, and sample from the model. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] + super().__init__() + + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + if mode == "concat": + if condition is None: + raise ValueError("Conditioning is required for concat condition") + else: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) + + return prediction + + @torch.no_grad() + def sample( + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. predict noise model_output + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffusion_model( + model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None + ) + else: + model_output = diffusion_model( + image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning + ) + + # 2. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if mode == "concat" and conditioning is None: + raise ValueError("Conditioning must be supplied for if condition mode is concat.") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + diffusion_model = ( + partial(diffusion_model, seg=seg) + if isinstance(diffusion_model, SPADEDiffusionModelUNet) + else diffusion_model + ) + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None) + else: + model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) + + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -self._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + def _approx_standard_normal_cdf(self, x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. Code adapted from https://github.com/openai/improved-diffusion. + """ + + return 0.5 * ( + 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3))) + ) + + def _get_decoder_log_likelihood( + self, + inputs: torch.Tensor, + means: torch.Tensor, + log_scales: torch.Tensor, + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + ) -> torch.Tensor: + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. Code adapted from https://github.com/openai/improved-diffusion. + + Args: + input: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + means: the Gaussian mean Tensor. + log_scales: the Gaussian log stddev Tensor. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + """ + if inputs.shape != means.shape: + raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}") + bin_width = (scaled_input_range[1] - scaled_input_range[0]) / ( + original_input_range[1] - original_input_range[0] + ) + centered_x = inputs - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + bin_width / 2) + cdf_plus = self._approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - bin_width / 2) + cdf_min = self._approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + inputs < -0.999, + log_cdf_plus, + torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + return log_probs + + +class LatentDiffusionInferer(DiffusionInferer): + """ + LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can + be used to perform a signal forward pass for a training iteration, and sample from the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + prediction: torch.Tensor = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + noise=noise, + timesteps=timesteps, + condition=condition, + mode=mode, + seg=seg, + ) + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and" + f"{diffusion_model.label_nc}" + ) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class ControlNetDiffusionInferer(DiffusionInferer): + """ + ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal + forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning. + + Args: + scheduler: diffusion scheduler. + """ + + def __init__(self, scheduler: Scheduler) -> None: + Inferer.__init__(self) + self.scheduler = scheduler + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: Input image to which noise is added. + diffusion_model: diffusion model. + controlnet: controlnet sub-network. + noise: random noise, of the same shape as the input. + timesteps: random timesteps. + cn_cond: conditioning image for the ControlNet. + condition: Conditioning for network input. + mode: Conditioning mode for the network. + seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be + provided on the forward (for SPADE-like AE or SPADE-like DM) + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond + ) + if mode == "concat" and condition is not None: + noisy_image = torch.cat([noisy_image, condition], dim=1) + condition = None + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + prediction: torch.Tensor = diffuse( + x=noisy_image, + timesteps=timesteps, + context=condition, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired sample. + diffusion_model: model to sample from. + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + + if not scheduler: + scheduler = self.scheduler + image = input_noise + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + for t in progress_bar: + # 1. ControlNet forward + down_block_res_samples, mid_block_res_sample = controlnet( + x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + ) + # 2. predict noise model_output + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + model_input = torch.cat([image, conditioning], dim=1) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + image, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + # 3. compute previous image: x_t -> x_t-1 + image, _ = scheduler.step(model_output, t, image) + if save_intermediates and t % intermediate_steps == 0: + intermediates.append(image) + if save_intermediates: + return image, intermediates + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple = (0, 255), + scaled_input_range: tuple = (0, 1), + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods for an input. + + Args: + inputs: input images, NxCxHxW[xD] + diffusion_model: model to compute likelihood from + controlnet: controlnet sub-network. + cn_cond: conditioning image for the ControlNet. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + + if not scheduler: + scheduler = self.scheduler + if scheduler._get_name() != "DDPMScheduler": + raise NotImplementedError( + f"Likelihood computation is only compatible with DDPMScheduler," + f" you are using {scheduler._get_name()}" + ) + if mode not in ["crossattn", "concat"]: + raise NotImplementedError(f"{mode} condition is not supported") + if verbose and has_tqdm: + progress_bar = tqdm(scheduler.timesteps) + else: + progress_bar = iter(scheduler.timesteps) + intermediates = [] + noise = torch.randn_like(inputs).to(inputs.device) + total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) + for t in progress_bar: + timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() + noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + ) + + diffuse = diffusion_model + if isinstance(diffusion_model, SPADEDiffusionModelUNet): + diffuse = partial(diffusion_model, seg=seg) + + if mode == "concat" and conditioning is not None: + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + model_output = diffuse( + x=noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted + if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if scheduler.prediction_type == "epsilon": + pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif scheduler.prediction_type == "sample": + pred_original_sample = model_output + elif scheduler.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output + # 3. Clip "predicted x_0" + if scheduler.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t + current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image + + # get the posterior mean and variance + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + + log_posterior_variance = torch.log(posterior_variance) + log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + + if t == 0: + # compute -log p(x_0|x_1) + kl = -super()._get_decoder_log_likelihood( + inputs=inputs, + means=predicted_mean, + log_scales=0.5 * log_predicted_variance, + original_input_range=original_input_range, + scaled_input_range=scaled_input_range, + ) + else: + # compute kl between two normals + kl = 0.5 * ( + -1.0 + + log_predicted_variance + - log_posterior_variance + + torch.exp(log_posterior_variance - log_predicted_variance) + + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance) + ) + total_kl += kl.view(kl.shape[0], -1).mean(dim=1) + if save_intermediates: + intermediates.append(kl.cpu()) + + if save_intermediates: + return total_kl, intermediates + else: + return total_kl + + +class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer): + """ + ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet, + and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from + the model. + + Args: + scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents. + scale_factor: scale factor to multiply the values of the latent representation before processing it by the + second stage. + ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape. + autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a + difference between the autoencoder's latent shape and the DM shape. + """ + + def __init__( + self, + scheduler: Scheduler, + scale_factor: float = 1.0, + ldm_latent_shape: list | None = None, + autoencoder_latent_shape: list | None = None, + ) -> None: + super().__init__(scheduler=scheduler) + self.scale_factor = scale_factor + if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None): + raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.") + self.ldm_latent_shape = ldm_latent_shape + self.autoencoder_latent_shape = autoencoder_latent_shape + if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None: + self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape) + self.autoencoder_resizer = CenterSpatialCrop(roi_size=[-1] + self.autoencoder_latent_shape) + + def __call__( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + noise: torch.Tensor, + timesteps: torch.Tensor, + cn_cond: torch.Tensor, + condition: torch.Tensor | None = None, + mode: str = "crossattn", + seg: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted and noise is added. + autoencoder_model: first stage model. + diffusion_model: diffusion model. + controlnet: instance of ControlNet model + noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. + cn_cond: conditioning tensor for the ControlNet network + condition: conditioning for network input. + mode: Conditioning mode for the network. + seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + """ + with torch.no_grad(): + latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if self.ldm_latent_shape is not None: + latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) + + if cn_cond.shape[2:] != latent.shape[2:]: + cn_cond = F.interpolate(cn_cond, latent.shape[2:]) + + prediction = super().__call__( + inputs=latent, + diffusion_model=diffusion_model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=cn_cond, + condition=condition, + mode=mode, + seg=seg, + ) + + return prediction + + @torch.no_grad() + def sample( # type: ignore[override] + self, + input_noise: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + intermediate_steps: int | None = 100, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + verbose: bool = True, + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Args: + input_noise: random noise, of the same shape as the desired latent representation. + autoencoder_model: first stage model. + diffusion_model: model to sample from. + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler. + save_intermediates: whether to return intermediates along the sampling change + intermediate_steps: if save_intermediates is True, saves every n steps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + verbose: if true, prints the progression bar of the sampling process. + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + + if ( + isinstance(autoencoder_model, SPADEAutoencoderKL) + and isinstance(diffusion_model, SPADEDiffusionModelUNet) + and autoencoder_model.decoder.label_nc != diffusion_model.label_nc + ): + raise ValueError( + "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic" + "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}" + ) + + if cn_cond.shape[2:] != input_noise.shape[2:]: + cn_cond = F.interpolate(cn_cond, input_noise.shape[2:]) + + outputs = super().sample( + input_noise=input_noise, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + intermediate_steps=intermediate_steps, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates: + latent, latent_intermediates = outputs + else: + latent = outputs + + if self.autoencoder_latent_shape is not None: + latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) + latent_intermediates = [ + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + ] + + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + + image = decode(latent / self.scale_factor) + + if save_intermediates: + intermediates = [] + for latent_intermediate in latent_intermediates: + decode = autoencoder_model.decode_stage_2_outputs + if isinstance(autoencoder_model, SPADEAutoencoderKL): + decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg) + intermediates.append(decode(latent_intermediate / self.scale_factor)) + return image, intermediates + + else: + return image + + @torch.no_grad() + def get_likelihood( # type: ignore[override] + self, + inputs: torch.Tensor, + autoencoder_model: AutoencoderKL | VQVAE, + diffusion_model: DiffusionModelUNet, + controlnet: ControlNet, + cn_cond: torch.Tensor, + scheduler: Scheduler | None = None, + save_intermediates: bool | None = False, + conditioning: torch.Tensor | None = None, + mode: str = "crossattn", + original_input_range: tuple | None = (0, 255), + scaled_input_range: tuple | None = (0, 1), + verbose: bool = True, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + seg: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + autoencoder_model: first stage model. + diffusion_model: model to compute likelihood from + controlnet: instance of ControlNet model. + cn_cond: conditioning tensor for the ControlNet network. + scheduler: diffusion scheduler. If none provided will use the class attribute scheduler + save_intermediates: save the intermediate spatial KL maps + conditioning: Conditioning for network input. + mode: Conditioning mode for the network. + original_input_range: the [min,max] intensity range of the input data before any scaling was applied. + scaled_input_range: the [min,max] intensity range of the input data after scaling. + verbose: if true, prints the progression bar of the sampling process. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model + is instance of SPADEAutoencoderKL, segmentation must be provided. + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + if cn_cond.shape[2:] != latents.shape[2:]: + cn_cond = F.interpolate(cn_cond, latents.shape[2:]) + + if self.ldm_latent_shape is not None: + latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) + + outputs = super().get_likelihood( + inputs=latents, + diffusion_model=diffusion_model, + controlnet=controlnet, + cn_cond=cn_cond, + scheduler=scheduler, + save_intermediates=save_intermediates, + conditioning=conditioning, + mode=mode, + verbose=verbose, + seg=seg, + ) + + if save_intermediates and resample_latent_likelihoods: + intermediates = outputs[1] + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + intermediates = [resizer(x) for x in intermediates] + outputs = (outputs[0], intermediates) + return outputs + + +class VQVAETransformerInferer(nn.Module): + """ + Class to perform inference with a VQVAE + Transformer model. + """ + + def __init__(self) -> None: + Inferer.__init__(self) + + def __call__( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + return_latent: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]: + """ + Implements the forward pass for a supervised training iteration. + + Args: + inputs: input image to which the latent representation will be extracted. + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + return_latent: also return latent sequence and spatial dim of the latent. + condition: conditioning for network input. + """ + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + + # get the targets for the loss + target = latent.clone() + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + # crop the last token as we do not need the probability of the token that follows it + latent = latent[:, :-1] + latent = latent.long() + + # train on a part of the sequence if it is longer than max_seq_length + seq_len = latent.shape[1] + max_seq_len = transformer_model.max_seq_len + if max_seq_len < seq_len: + start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item()) + else: + start = 0 + prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition) + if return_latent: + return prediction, target[:, start : start + max_seq_len], latent_spatial_dim + else: + return prediction + + @torch.no_grad() + def sample( + self, + latent_spatial_dim: tuple[int, int, int] | tuple[int, int], + starting_tokens: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + conditioning: torch.Tensor | None = None, + temperature: float = 1.0, + top_k: int | None = None, + verbose: bool = True, + ) -> torch.Tensor: + """ + Sampling function for the VQVAE + Transformer model. + + Args: + latent_spatial_dim: shape of the sampled image. + starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value. + vqvae_model: first stage model. + transformer_model: model to sample from. + conditioning: Conditioning for network input. + temperature: temperature for sampling. + top_k: top k sampling. + verbose: if true, prints the progression bar of the sampling process. + """ + seq_len = math.prod(latent_spatial_dim) + + if verbose and has_tqdm: + progress_bar = tqdm(range(seq_len)) + else: + progress_bar = iter(range(seq_len)) + + latent_seq = starting_tokens.long() + for _ in progress_bar: + # if the sequence context is growing too long we must crop it at block_size + if latent_seq.size(1) <= transformer_model.max_seq_len: + idx_cond = latent_seq + else: + idx_cond = latent_seq[:, -transformer_model.max_seq_len :] + + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=conditioning) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # remove the chance to be sampled the BOS token + probs[:, vqvae_model.num_embeddings] = 0 + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + latent_seq = torch.cat((latent_seq, idx_next), dim=1) + + latent_seq = latent_seq[:, 1:] + latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()] + latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim) + + return vqvae_model.decode_samples(latent) + + @torch.no_grad() + def get_likelihood( + self, + inputs: torch.Tensor, + vqvae_model: VQVAE, + transformer_model: DecoderOnlyTransformer, + ordering: Ordering, + condition: torch.Tensor | None = None, + resample_latent_likelihoods: bool = False, + resample_interpolation_mode: str = "nearest", + verbose: bool = False, + ) -> torch.Tensor: + """ + Computes the log-likelihoods of the latent representations of the input. + + Args: + inputs: input images, NxCxHxW[xD] + vqvae_model: first stage model. + transformer_model: autoregressive transformer model. + ordering: ordering of the quantised latent representation. + condition: conditioning for network input. + resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial + dimension as the input images. + resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear', + or 'trilinear; + verbose: if true, prints the progression bar of the sampling process. + + """ + if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): + raise ValueError( + f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" + ) + + with torch.no_grad(): + latent = vqvae_model.index_quantize(inputs) + + latent_spatial_dim = tuple(latent.shape[1:]) + latent = latent.reshape(latent.shape[0], -1) + latent = latent[:, ordering.get_sequence_ordering()] + seq_len = math.prod(latent_spatial_dim) + + # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token. + # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens. + latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings) + latent = latent.long() + + # get the first batch, up to max_seq_length, efficiently + logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition) + probs = F.softmax(logits, dim=-1) + # target token for each set of logits is the next token along + target = latent[:, 1:] + probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2) + + # if we have not covered the full sequence we continue with inefficient looping + if probs.shape[1] < target.shape[1]: + if verbose and has_tqdm: + progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len)) + else: + progress_bar = iter(range(transformer_model.max_seq_len, seq_len)) + + for i in progress_bar: + idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1] + # forward the model to get the logits for the index in the sequence + logits = transformer_model(x=idx_cond, context=condition) + # pluck the logits at the final step + logits = logits[:, -1, :] + # apply softmax to convert logits to (normalized) probabilities + p = F.softmax(logits, dim=-1) + # select correct values and append + p = torch.gather(p, 1, target[:, i].unsqueeze(1)) + + probs = torch.cat((probs, p), dim=1) + + # convert to log-likelihood + probs = torch.log(probs) + + # reshape + probs = probs[:, ordering.get_revert_sequence_ordering()] + probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim) + if resample_latent_likelihoods: + resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode) + probs_reshaped = resizer(probs_reshaped[:, None, ...]) + + return probs_reshaped diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 1532215c70..0441cc9cfe 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: batch, channel, height, width, depth = x.shape # norm - x = self.norm(x) + x = self.norm(x.contiguous()) if self.spatial_dims == 2: x = x.view(batch, channel, height * width).transpose(1, 2) @@ -682,7 +682,7 @@ def __init__( ) def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x + h = x.contiguous() h = self.norm1(h) h = self.nonlinearity(h) @@ -1957,7 +1957,7 @@ def forward( h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) # 7. output block - output: torch.Tensor = self.out(h) + output: torch.Tensor = self.out(h.contiguous()) return output diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index d53327100e..bffc9c5465 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -618,7 +618,7 @@ class SPADEDiffusionModelUNet(nn.Module): out_channels: number of output channels. label_nc: number of semantic channels for SPADE normalisation. num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. + channels: tuple of block output channels. attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. @@ -641,7 +641,7 @@ def __init__( out_channels: int, label_nc: int, num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), + channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, @@ -667,10 +667,10 @@ def __init__( ) # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - if len(num_channels) != len(attention_levels): + if len(channels) != len(attention_levels): raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): @@ -683,9 +683,9 @@ def __init__( ) if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels)) - if len(num_res_blocks) != len(num_channels): + if len(num_res_blocks) != len(channels): raise ValueError( "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " "`num_channels`." @@ -700,7 +700,7 @@ def __init__( ) self.in_channels = in_channels - self.block_out_channels = num_channels + self.block_out_channels = channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_levels = attention_levels @@ -712,7 +712,7 @@ def __init__( self.conv_in = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels[0], + out_channels=channels[0], strides=1, kernel_size=3, padding=1, @@ -720,9 +720,9 @@ def __init__( ) # time - time_embed_dim = num_channels[0] * 4 + time_embed_dim = channels[0] * 4 self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # class embedding @@ -732,11 +732,11 @@ def __init__( # down self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): + output_channel = channels[0] + for i in range(len(channels)): input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 + output_channel = channels[i] + is_final_block = i == len(channels) - 1 down_block = get_down_block( spatial_dims=spatial_dims, @@ -762,7 +762,7 @@ def __init__( # mid self.middle_block = get_mid_block( spatial_dims=spatial_dims, - in_channels=num_channels[-1], + in_channels=channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, norm_eps=norm_eps, @@ -776,7 +776,7 @@ def __init__( # up self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) + reversed_block_out_channels = list(reversed(channels)) reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) @@ -784,9 +784,9 @@ def __init__( for i in range(len(reversed_block_out_channels)): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)] - is_final_block = i == len(num_channels) - 1 + is_final_block = i == len(channels) - 1 up_block = get_spade_up_block( spatial_dims=spatial_dims, @@ -814,12 +814,12 @@ def __init__( # out self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True), nn.SiLU(), zero_module( Convolution( spatial_dims=spatial_dims, - in_channels=num_channels[0], + in_channels=channels[0], out_channels=out_channels, strides=1, kernel_size=3, diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2c32eb2cf4..03fa1ceed1 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -126,6 +126,7 @@ version_leq, ) from .nvtx import Range +from .ordering import Ordering from .profiling import ( PerfContext, ProfileHandler, diff --git a/setup.cfg b/setup.cfg index 123da68dfa..0069214de3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ all = scipy>=1.7.1 pillow tensorboard - gdown>=4.4.0 + gdown==4.6.3 pytorch-ignite==0.4.11 torchvision itk>=5.2 @@ -60,12 +60,12 @@ all = lmdb psutil cucim>=23.2.0 - openslide-python==1.1.2 + openslide-python tifffile imagecodecs pandas einops - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib @@ -97,7 +97,7 @@ pillow = tensorboard = tensorboard gdown = - gdown>=4.4.0 + gdown==4.6.3 ignite = pytorch-ignite==0.4.11 torchvision = @@ -113,7 +113,7 @@ psutil = cucim = cucim>=23.2.0 openslide = - openslide-python==1.1.2 + openslide-python tifffile = tifffile imagecodecs = @@ -123,7 +123,7 @@ pandas = einops = einops transformers = - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow = mlflow matplotlib = @@ -173,6 +173,7 @@ max_line_length = 120 # B028 https://github.com/Project-MONAI/MONAI/issues/5855 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 +# B036 https://github.com/Project-MONAI/MONAI/issues/7396 ignore = E203 E501 @@ -186,6 +187,7 @@ ignore = B028 B907 B908 + B036 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py new file mode 100644 index 0000000000..1f675537dc --- /dev/null +++ b/tests/test_controlnet_inferers.py @@ -0,0 +1,1270 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from monai.networks.nets import ( + VQVAE, + AutoencoderKL, + ControlNet, + DiffusionModelUNet, + SPADEAutoencoderKL, + SPADEDiffusionModelUNet, +) +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +CNDM_TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "norm_num_groups": 8, + "num_res_blocks": 1, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 1, + "channels": [8], + "attention_levels": [True], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (2, 1, 8, 8, 8), + ], +] +LATENT_CNDM_TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], +] +LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + { + "spatial_dims": 3, + "in_channels": 3, + "channels": [8, 8], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 8, + "num_head_channels": 8, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + { + "spatial_dims": 2, + "in_channels": 3, + "channels": [4, 4], + "attention_levels": [False, False], + "num_res_blocks": 1, + "norm_num_groups": 4, + "num_head_channels": 4, + "conditioning_embedding_num_channels": [16], + "conditioning_embedding_in_channels": 1, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + +class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_intermediates( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + controlnet=controlnet, + cn_cond=mask, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + controlnet=controlnet, + cn_cond=mask, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_get_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_resample_likelihoods( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES) + def test_sample_shape_conditioned_concat( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, + ae_model_type, + autoencoder_params, + dm_model_type, + stage_2_params, + controlnet_params, + input_shape, + latent_shape, + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + controlnet = ControlNet(**controlnet_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = ControlNetLatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + controlnet=controlnet, + cn_cond=mask, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + controlnet=controlnet, + cn_cond=mask, + timesteps=timesteps, + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + controlnet = ControlNet( + spatial_dims=2, + in_channels=1, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + conditioning_embedding_num_channels=[16], + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + controlnet.to(device) + controlnet.to(device) + stage_1.eval() + stage_2.eval() + controlnet.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + mask = torch.randn((1, 1, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py new file mode 100644 index 0000000000..ecd4855385 --- /dev/null +++ b/tests/test_diffusion_inferer.py @@ -0,0 +1,226 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_call(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_ddim_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + @unittest.skipUnless(has_scipy, "Requires scipy library.") + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(TEST_CASES) + def test_call_conditioned_concat(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + input = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat" + ) + self.assertEqual(sample.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 1218ce6e85..1d831f0976 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -39,7 +39,7 @@ class DummyEncoder(BaseEncoder): def get_encoder_parameters(cls): basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False} param_dict_list = [basic_dict] - for key in basic_dict: + for key in basic_dict.keys(): cur_dict = basic_dict.copy() del cur_dict[key] param_dict_list.append(cur_dict) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index cd2e91257a..2e6ee35981 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -112,15 +112,15 @@ def test_invert(self): self.assertTupleEqual(i.shape[1:], (101, 100, 107)) # check the case that different items use different interpolation mode to invert transforms - d = item["image_inverted1"] + j = item["image_inverted1"] # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0) + self.assertTupleEqual(j.shape, (1, 101, 100, 107)) - d = item["label_inverted1"] + k = item["label_inverted1"] # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (1, 101, 100, 107)) + self.assertGreater(torch.sum(k.to(torch.float) - k.to(torch.uint8).to(torch.float)).item(), 10000.0) + self.assertTupleEqual(k.shape, (1, 101, 100, 107)) # check labels match reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py new file mode 100644 index 0000000000..4ab803bb6f --- /dev/null +++ b/tests/test_latent_diffusion_inferer.py @@ -0,0 +1,796 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import LatentDiffusionInferer +from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16), + (1, 3, 4, 4), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 16, 16, 16), + (1, 3, 4, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] +TEST_CASES_DIFF_SHAPES = [ + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12), + (1, 3, 8, 8), + ], + [ + "VQVAE", + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [4, 4], + "num_res_layers": 1, + "num_res_channels": [4, 4], + "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)), + "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + "num_embeddings": 16, + "embedding_dim": 3, + }, + "DiffusionModelUNet", + { + "spatial_dims": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [8, 8], + "norm_num_groups": 8, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (1, 1, 12, 12, 12), + (1, 3, 8, 8, 8), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "DiffusionModelUNet", + { + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "AutoencoderKL", + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], + [ + "SPADEAutoencoderKL", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4), + "latent_channels": 3, + "attention_levels": [False, False], + "num_res_blocks": 1, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + "norm_num_groups": 4, + }, + "SPADEDiffusionModelUNet", + { + "spatial_dims": 2, + "label_nc": 3, + "in_channels": 3, + "out_channels": 3, + "channels": [4, 4], + "norm_num_groups": 4, + "attention_levels": [False, False], + "num_res_blocks": 1, + "num_head_channels": 4, + }, + (1, 1, 8, 8), + (1, 3, 4, 4), + ], +] + + +class TestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + seg=input_seg, + noise=noise, + timesteps=timesteps, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_intermediates( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + save_intermediates=True, + intermediate_steps=1, + ) + else: + sample, intermediates = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, input_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_resample_likelihoods( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + seg=input_seg, + ) + else: + sample, intermediates = inferer.get_likelihood( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + save_intermediates=True, + resample_latent_likelihoods=True, + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + condition=conditioning, + mode="concat", + ) + self.assertEqual(prediction.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_sample_shape_conditioned_concat( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + stage_2_params = stage_2_params.copy() + n_concat_channel = 3 + stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + conditioning_shape = list(latent_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + seg=input_seg, + ) + else: + sample = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(TEST_CASES_DIFF_SHAPES) + def test_sample_shape_different_latents( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + if ae_model_type == "VQVAE": + stage_1 = VQVAE(**autoencoder_params) + if ae_model_type == "SPADEAutoencoderKL": + stage_1 = SPADEAutoencoderKL(**autoencoder_params) + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + noise = torch.randn(latent_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + # We infer the VAE shape + autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]] + inferer = LatentDiffusionInferer( + scheduler=scheduler, + scale_factor=1.0, + ldm_latent_shape=list(latent_shape[2:]), + autoencoder_latent_shape=autoencoder_latent_shape, + ) + scheduler.set_timesteps(num_inference_steps=10) + + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + + if dm_model_type == "SPADEDiffusionModelUNet": + input_shape_seg = list(input_shape) + if "label_nc" in stage_2_params.keys(): + input_shape_seg[1] = stage_2_params["label_nc"] + else: + input_shape_seg[1] = autoencoder_params["label_nc"] + input_seg = torch.randn(input_shape_seg).to(device) + prediction = inferer( + inputs=input, + autoencoder_model=stage_1, + diffusion_model=stage_2, + noise=noise, + timesteps=timesteps, + seg=input_seg, + ) + else: + prediction = inferer( + inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) + self.assertEqual(prediction.shape, latent_shape) + + def test_incompatible_spade_setup(self): + stage_1 = SPADEAutoencoderKL( + spatial_dims=2, + label_nc=6, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = SPADEDiffusionModelUNet( + spatial_dims=2, + label_nc=3, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + noise = torch.randn((1, 3, 4, 4)).to(device) + input_seg = torch.randn((1, 3, 8, 8)).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + with self.assertRaises(ValueError): + _ = inferer.sample( + input_noise=noise, + autoencoder_model=stage_1, + diffusion_model=stage_2, + scheduler=scheduler, + seg=input_seg, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 0c52dba5e5..e6b235e179 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -182,24 +182,6 @@ ], ] -TEST_2D_RANDOM = [ - [ - { - "ordering_type": OrderingType.RANDOM, - "spatial_dims": 2, - "dimensions": (1, 2, 2), - "reflected_spatial_dims": (True, False), - "transpositions_axes": ((1, 0),), - "rot90_axes": ((0, 1),), - "transformation_order": ( - OrderingTransformations.TRANSPOSE.value, - OrderingTransformations.ROTATE_90.value, - OrderingTransformations.REFLECT.value, - ), - }, - [[0, 1, 2, 3], [0, 1, 3, 2]], - ] -] TEST_3D = [ [ @@ -291,17 +273,6 @@ def test_ordering_transformation_failure(self, input_param): with self.assertRaises(ValueError): Ordering(**input_param) - @parameterized.expand(TEST_2D_RANDOM) - def test_random(self, input_param, not_in_expected_sequence_ordering): - ordering = Ordering(**input_param) - - not_in = [ - np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) - for sequence in not_in_expected_sequence_ordering - ] - - self.assertFalse(np.any(not_in)) - @parameterized.expand(TEST_REVERT) def test_revert(self, input_param): sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() diff --git a/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py similarity index 100% rename from test_spade_autoencoderkl.py rename to tests/test_spade_autoencoderkl.py diff --git a/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py similarity index 92% rename from test_spade_diffusion_model_unet.py rename to tests/test_spade_diffusion_model_unet.py index c8a2103cf6..113e58ed89 100644 --- a/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -26,7 +26,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -38,7 +38,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": (1, 1, 2), - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -50,7 +50,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -63,7 +63,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -76,7 +76,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -90,7 +90,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -103,7 +103,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, True, True), "num_head_channels": (0, 2, 4), "norm_num_groups": 8, @@ -119,7 +119,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -132,7 +132,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "label_nc": 3, @@ -144,7 +144,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, False), "norm_num_groups": 8, "resblock_updown": True, @@ -157,7 +157,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -170,7 +170,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, @@ -184,7 +184,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -197,7 +197,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": (0, 0, 4), "norm_num_groups": 8, @@ -213,7 +213,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -229,7 +229,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -246,7 +246,7 @@ "in_channels": 1, "out_channels": 1, "num_res_blocks": 1, - "num_channels": (8, 8, 8), + "channels": (8, 8, 8), "attention_levels": (False, False, True), "num_head_channels": 4, "norm_num_groups": 8, @@ -279,7 +279,7 @@ def test_timestep_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -296,7 +296,7 @@ def test_label_with_wrong_shape(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -313,7 +313,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -331,7 +331,7 @@ def test_model_channels_not_multiple_of_norm_num_group(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 12), + channels=(8, 8, 12), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -344,13 +344,13 @@ def test_attention_levels_with_different_length_num_head_channels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), num_head_channels=(0, 2), norm_num_groups=8, ) - def test_num_res_blocks_with_different_length_num_channels(self): + def test_num_res_blocks_with_different_length_channels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -358,7 +358,7 @@ def test_num_res_blocks_with_different_length_num_channels(self): in_channels=1, out_channels=1, num_res_blocks=(1, 1), - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, False), norm_num_groups=8, ) @@ -370,7 +370,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -395,7 +395,7 @@ def test_with_conditioning_cross_attention_dim_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=True, transformer_num_layers=1, @@ -410,7 +410,7 @@ def test_context_with_conditioning_none(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), with_conditioning=False, transformer_num_layers=1, @@ -433,7 +433,7 @@ def test_shape_conditioned_models_class_conditioning(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -455,7 +455,7 @@ def test_conditioned_models_no_class_labels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=8, num_head_channels=8, @@ -469,7 +469,7 @@ def test_conditioned_models_no_class_labels(self): seg=torch.rand((1, 3, 16, 32)), ) - def test_model_num_channels_not_same_size_of_attention_levels(self): + def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( spatial_dims=2, @@ -477,7 +477,7 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False), norm_num_groups=8, num_head_channels=8, @@ -518,7 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): in_channels=in_channels, out_channels=out_channels, num_res_blocks=1, - num_channels=(8, 8, 8), + channels=(8, 8, 8), attention_levels=(False, False, True), norm_num_groups=4, ) @@ -537,7 +537,7 @@ def test_shape_conditioned_models(self): in_channels=1, out_channels=1, num_res_blocks=1, - num_channels=(16, 16, 16), + channels=(16, 16, 16), attention_levels=(False, False, True), norm_num_groups=16, with_conditioning=True, diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py new file mode 100644 index 0000000000..1a511d287b --- /dev/null +++ b/tests/test_vqvaetransformer_inferer.py @@ -0,0 +1,284 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import VQVAETransformerInferer +from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils.ordering import Ordering, OrderingType + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 4, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)}, + (2, 1, 8, 8), + (2, 4, 17), + (2, 2, 2), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": (8, 8), + "num_res_channels": (8, 8), + "downsample_parameters": ((2, 4, 1, 1),) * 2, + "upsample_parameters": ((2, 4, 1, 1, 0),) * 2, + "num_res_layers": 1, + "num_embeddings": 16, + "embedding_dim": 8, + }, + { + "num_tokens": 16 + 1, + "max_seq_len": 8, + "attn_layers_dim": 4, + "attn_layers_depth": 2, + "attn_layers_heads": 1, + "with_cross_attention": False, + }, + {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)}, + (2, 1, 8, 8, 8), + (2, 8, 17), + (2, 2, 2, 2), + ], +] + + +class TestVQVAETransformerInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_prediction_shape( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + self.assertEqual(prediction.shape, logits_shape) + + @parameterized.expand(TEST_CASES) + def test_prediction_shape_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering) + cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) + self.assertEqual(prediction.shape, cropped_logits_shape) + + def test_sample(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=4, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + def test_sample_shorter_sequence(self): + stage_1 = VQVAE( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 8), + num_res_channels=(8, 8), + downsample_parameters=((2, 4, 1, 1),) * 2, + upsample_parameters=((2, 4, 1, 1, 0),) * 2, + num_res_layers=1, + num_embeddings=16, + embedding_dim=8, + ) + stage_2 = DecoderOnlyTransformer( + num_tokens=16 + 1, + max_seq_len=2, + attn_layers_dim=4, + attn_layers_depth=2, + attn_layers_heads=1, + with_cross_attention=False, + ) + ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2)) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + inferer = VQVAETransformerInferer() + + starting_token = 16 # from stage_1 num_embeddings + + sample = inferer.sample( + latent_spatial_dim=(2, 2), + starting_tokens=starting_token * torch.ones((2, 1), device=device), + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + ) + self.assertEqual(sample.shape, (2, 1, 8, 8)) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_shorter_sequence( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + max_seq_len = 3 + stage_2_params_shorter = dict(stage_2_params) + stage_2_params_shorter["max_seq_len"] = max_seq_len + stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering + ) + self.assertEqual(likelihood.shape, latent_shape) + + @parameterized.expand(TEST_CASES) + def test_get_likelihood_resampling( + self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape + ): + stage_1 = VQVAE(**stage_1_params) + stage_2 = DecoderOnlyTransformer(**stage_2_params) + ordering = Ordering(**ordering_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + input = torch.randn(input_shape).to(device) + + inferer = VQVAETransformerInferer() + likelihood = inferer.get_likelihood( + inputs=input, + vqvae_model=stage_1, + transformer_model=stage_2, + ordering=ordering, + resample_latent_likelihoods=True, + resample_interpolation_mode="nearest", + ) + self.assertEqual(likelihood.shape, input_shape) + + +if __name__ == "__main__": + unittest.main() From bc69d8b79dd6cf7bd7d9310c2a10f1b947808339 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 18 Jan 2024 15:43:36 +0000 Subject: [PATCH 38/67] Adds prepare batch engine Signed-off-by: Mark Graham --- monai/engines/prepare_batch.py | 93 ++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 monai/engines/prepare_batch.py diff --git a/monai/engines/prepare_batch.py b/monai/engines/prepare_batch.py new file mode 100644 index 0000000000..5b68d2b3a0 --- /dev/null +++ b/monai/engines/prepare_batch.py @@ -0,0 +1,93 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Mapping + +import torch + +from monai.engines import PrepareBatch, default_prepare_batch +from monai.networks.schedulers import Scheduler + + +class DiffusionPrepareBatch(PrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". + This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images: torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, + non_blocking: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict[str, torch.Tensor]]: + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["conditioning"] = batchdata[self.condition_name].to( + device, non_blocking=non_blocking, **kwargs + ) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". + + """ + + def __init__(self, scheduler: Scheduler, num_train_timesteps: int, condition_name: str | None = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) From 2a8954b81cce3a794d15f0c0d614994547b1ca62 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 19 Jan 2024 14:26:17 +0000 Subject: [PATCH 39/67] Test prepare batch Signed-off-by: Mark Graham --- monai/engines/__init__.py | 2 + monai/engines/prepare_batch.py | 93 --------------------------- monai/engines/utils.py | 14 ++-- tests/test_prepare_batch_diffusion.py | 54 ++++++++++------ 4 files changed, 43 insertions(+), 120 deletions(-) delete mode 100644 monai/engines/prepare_batch.py diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index be65c39f70..93cc40e292 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -14,10 +14,12 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( + DiffusionPrepareBatch, IterationEvents, PrepareBatch, PrepareBatchDefault, PrepareBatchExtraInput, + VPredictionPrepareBatch, default_make_latent, default_metric_cmp_fn, default_prepare_batch, diff --git a/monai/engines/prepare_batch.py b/monai/engines/prepare_batch.py deleted file mode 100644 index 5b68d2b3a0..0000000000 --- a/monai/engines/prepare_batch.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any, Mapping - -import torch - -from monai.engines import PrepareBatch, default_prepare_batch -from monai.networks.schedulers import Scheduler - - -class DiffusionPrepareBatch(PrepareBatch): - """ - This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. - - Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and - return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". - This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. - - If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition - field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". - - """ - - def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: - self.condition_name = condition_name - self.num_train_timesteps = num_train_timesteps - - def get_noise(self, images: torch.Tensor) -> torch.Tensor: - """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" - return torch.randn_like(images) - - def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: - """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" - return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() - - def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """Return the target for the loss function, this is the `noise` value by default.""" - return noise - - def __call__( - self, - batchdata: dict[str, torch.Tensor], - device: str | torch.device | None = None, - non_blocking: bool = False, - **kwargs: Any, - ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict[str, torch.Tensor]]: - images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) - noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) - timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) - - target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) - infer_kwargs = {"noise": noise, "timesteps": timesteps} - - if self.condition_name is not None and isinstance(batchdata, Mapping): - infer_kwargs["conditioning"] = batchdata[self.condition_name].to( - device, non_blocking=non_blocking, **kwargs - ) - - # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value - return images, target, (), infer_kwargs - - -class VPredictionPrepareBatch(DiffusionPrepareBatch): - """ - This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. - - Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and - from this compute the velocity using the provided scheduler. This value is used as the target in place of the - noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer - being used in conjunction with this class expects a "noise" parameter to be provided. - - If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition - field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". - - """ - - def __init__(self, scheduler: Scheduler, num_train_timesteps: int, condition_name: str | None = None) -> None: - super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) - self.scheduler = scheduler - - def get_target(self, images, noise, timesteps): - return self.scheduler.get_velocity(images, noise, timesteps) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index a76645ce36..5339d6965a 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Mapping, cast import torch import torch.nn as nn @@ -254,7 +254,7 @@ class DiffusionPrepareBatch(PrepareBatch): """ - def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: self.condition_name = condition_name self.num_train_timesteps = num_train_timesteps @@ -272,8 +272,8 @@ def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch def __call__( self, - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, non_blocking: bool = False, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: @@ -285,9 +285,7 @@ def __call__( infer_kwargs = {"noise": noise, "timesteps": timesteps} if self.condition_name is not None and isinstance(batchdata, Mapping): - infer_kwargs["conditioning"] = batchdata[self.condition_name].to( - device, non_blocking=non_blocking, **kwargs - ) + infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value return images, target, (), infer_kwargs @@ -307,7 +305,7 @@ class VPredictionPrepareBatch(DiffusionPrepareBatch): """ - def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None: super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) self.scheduler = scheduler diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py index c72053de99..d969c06368 100644 --- a/tests/test_prepare_batch_diffusion.py +++ b/tests/test_prepare_batch_diffusion.py @@ -15,10 +15,12 @@ import torch from parameterized import parameterized + +from monai.engines import SupervisedEvaluator from monai.engines.utils import DiffusionPrepareBatch +from monai.inferers import DiffusionInferer from monai.networks.nets import DiffusionModelUNet -from monai.engines import SupervisedEvaluator -from tests.utils import assert_allclose +from monai.networks.schedulers import DDPMScheduler TEST_CASES = [ [ @@ -50,38 +52,52 @@ ] - class TestPrepareBatchDiffusion(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_content(self, input_args, image_size): + def test_output_sizes(self, input_args, image_size): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dataloader = [ - { - "image": torch.randn(image_size).to(device), - } - ] - # set up engine + dataloader = [{"image": torch.randn(image_size).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) network = DiffusionModelUNet(**input_args).to(device) - num_train_timesteps = 10 - scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps) - inferer = DiffusionInferer() evaluator = SupervisedEvaluator( device=device, val_data_loader=dataloader, epoch_length=1, network=network, + inferer=inferer, non_blocking=True, prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20), decollate=False, ) evaluator.run() output = evaluator.state.output - assert_allclose(output["image"], torch.tensor([1, 2], device=device)) - for k, v in output["pred"].items(): - if isinstance(v, torch.Tensor): - assert_allclose(v, expected_value[k].to(device)) - else: - self.assertEqual(v, expected_value[k]) + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + @parameterized.expand(TEST_CASES) + def test_conditioning(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device), "context": torch.randn((2, 4, 3)).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name="context"), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) if __name__ == "__main__": From e15a50fac03b292fa2721131adc6ffa0644202cd Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 29 Jan 2024 11:43:33 +0000 Subject: [PATCH 40/67] Skip nibabel test if not installed Signed-off-by: Mark Graham --- tests/test_integration_workflows_adversarial.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py index e0ff93bbe3..f323fc9917 100644 --- a/tests/test_integration_workflows_adversarial.py +++ b/tests/test_integration_workflows_adversarial.py @@ -17,7 +17,6 @@ import unittest from glob import glob -import nibabel as nib import numpy as np import torch @@ -28,9 +27,11 @@ from monai.networks.nets import AutoEncoder, Discriminator from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd from monai.utils import AdversarialKeys as Keys -from monai.utils import CommonKeys, set_determinism +from monai.utils import CommonKeys, optional_import, set_determinism from tests.utils import DistTestCase, TimedCall, skip_if_quick +nib, has_nibabel = optional_import("nibabel") + def run_training_test(root_dir, device="cuda:0"): learning_rate = 2e-4 @@ -139,6 +140,7 @@ def generator_loss(fake_logits): @skip_if_quick +@unittest.skipUnless(has_nibabel, "Requires nibabel library.") class IntegrationWorkflowsAdversarialTrainer(DistTestCase): def setUp(self): set_determinism(seed=0) From f60c4e28dc224060949f7387e2794d2f682390ef Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 29 Jan 2024 12:02:30 +0000 Subject: [PATCH 41/67] Adds prepare batch to min tests Signed-off-by: Mark Graham --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 8128bb7b84..3a143df84b 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -154,6 +154,7 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_prepare_batch_default", + "test_prepare_batch_diffusion", "test_prepare_batch_extra_input", "test_prepare_batch_hovernet", "test_rand_grid_patch", From 11c12856313a51e9a8ef53caaf0dc7bf21f488f4 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 5 Dec 2023 13:20:01 +0800 Subject: [PATCH 42/67] update the Python version requirements for transformers (#7275) Part of #7250. ### Description Fix the Python version for transformers smaller than 3.10. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: Mark Graham --- docs/requirements.txt | 2 +- requirements-dev.txt | 2 +- tests/test_transchex.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index a9bbc384f8..e5bedf8552 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 tensorboardX diff --git a/requirements-dev.txt b/requirements-dev.txt index 6332d5b0a5..cacbefe234 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 9ad847cdaa..8fb1f56715 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.transchex import Transchex -from tests.utils import skip_if_quick +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick TEST_CASE_TRANSCHEX = [] for drop_out in [0.4]: @@ -46,6 +46,7 @@ @skip_if_quick +@SkipIfAtLeastPyTorchVersion((1, 10)) class TestTranschex(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSCHEX) def test_shape(self, input_param, expected_shape): From e15b5707ba77d200ef7e8059791f2896b3d8dbc5 Mon Sep 17 00:00:00 2001 From: Kaibo Tang Date: Tue, 5 Dec 2023 03:46:24 -0500 Subject: [PATCH 43/67] 7263 add diffusion loss (#7272) Fixes #7263. ### Description Add diffusion loss. I also made a [demo notebook](https://github.com/kvttt/deep-atlas/blob/main/diffusion_loss_scale_test.ipynb) to provide some explanations and analyses of diffusion loss. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: kaibo Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/losses.rst | 5 ++ monai/losses/__init__.py | 2 +- monai/losses/deform.py | 82 +++++++++++++++++++++++++ tests/test_diffusion_loss.py | 116 +++++++++++++++++++++++++++++++++++ 4 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 tests/test_diffusion_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 568c7dfc77..e929e9d605 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -96,6 +96,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..129abeedd2 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of pred using central finite difference. + For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. + g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g**2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb --- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when the all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() From fdbc6110302b3014c191c60802864f40a87ff9e2 Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:21:27 +0100 Subject: [PATCH 44/67] [Attention block] relative positional embedding (#7346) Fixes #7356 ### Description Add relative positinoal embedding in attention block as described in https://arxiv.org/pdf/2112.01526.pdf Largely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py Can be useful for #6357 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vgrau98 Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- docs/source/networks.rst | 6 + monai/networks/blocks/attention_utils.py | 128 +++++++++++++++++++++ monai/networks/blocks/rel_pos_embedding.py | 56 +++++++++ monai/networks/blocks/selfattention.py | 33 +++++- monai/networks/layers/factories.py | 13 ++- monai/networks/layers/utils.py | 15 ++- tests/test_selfattention.py | 21 +++- 7 files changed, 262 insertions(+), 10 deletions(-) create mode 100644 monai/networks/blocks/attention_utils.py create mode 100644 monai/networks/blocks/rel_pos_embedding.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index f9375f1e97..556bf12d50 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -248,6 +248,12 @@ Blocks .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock :members: +`Attention utilities` +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.attention_utils +.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos +.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos + N-Dim Fourier Transform ~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: monai.networks.blocks.fft_utils_t diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py new file mode 100644 index 0000000000..8c9002a16e --- /dev/null +++ b/monai/networks/blocks/attention_utils.py @@ -0,0 +1,128 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + rel_pos_resized: torch.Tensor = torch.Tensor() + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple +) -> torch.Tensor: + r""" + Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Only 2D and 3D are supported. + + Encoding the relative position of tokens in the attention matrix: tokens spaced a distance + `d` apart will have the same embedding value (unlike absolute positional embedding). + + .. math:: + Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale + + where + + .. math:: + E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)} + + with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`, + respectively spatial positions of element :math:`i` and :math:`j` + + When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow: + + .. math:: + R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)} + + with :math:`n = 1...dim` + + Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to + :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding. + + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). + rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. + q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). + k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). + + Returns: + attn (Tensor): attention logits with added relative positional embeddings. + """ + rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) + rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) + + batch, _, dim = q.shape + + if len(rel_pos_lst) == 2: + q_h, q_w = q_size[:2] + k_h, k_w = k_size[:2] + r_q = q.reshape(batch, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) + + attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + batch, q_h * q_w, k_h * k_w + ) + elif len(rel_pos_lst) == 3: + q_h, q_w, q_d = q_size[:3] + k_h, k_w, k_d = k_size[:3] + + rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) + + r_q = q.reshape(batch, q_h, q_w, q_d, dim) + rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) + rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) + rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) + + attn = ( + attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) + + rel_h[:, :, :, :, None, None] + + rel_w[:, :, :, None, :, None] + + rel_d[:, :, :, None, None, :] + ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) + + return attn diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py new file mode 100644 index 0000000000..e53e5841b0 --- /dev/null +++ b/monai/networks/blocks/rel_pos_embedding.py @@ -0,0 +1,56 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Iterable, Tuple + +import torch +from torch import nn + +from monai.networks.blocks.attention_utils import add_decomposed_rel_pos +from monai.utils.misc import ensure_tuple_size + + +class DecomposedRelativePosEmbedding(nn.Module): + def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: + """ + Args: + s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) + c_dim (int): channel dimension + num_heads(int): number of attention heads + """ + super().__init__() + + # validate inputs + if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: + raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") + + self.s_input_dims = s_input_dims + self.c_dim = c_dim + self.num_heads = num_heads + self.rel_pos_arr = nn.ParameterList( + [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] + ) + + def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + """""" + batch = x.shape[0] + h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) + + att_mat = add_decomposed_rel_pos( + att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), + q.contiguous().view(batch * self.num_heads, h * w * d, -1), + self.rel_pos_arr, + (h, w) if d == 1 else (h, w, d), + (h, w) if d == 1 else (h, w, d), + ) + + att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) + return att_mat diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 7c81c1704f..3bef24b4e8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,9 +11,12 @@ from __future__ import annotations +from typing import Optional, Tuple + import torch import torch.nn as nn +from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -23,6 +26,7 @@ class SABlock(nn.Module): """ A self-attention block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in """ def __init__( @@ -32,6 +36,8 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, ) -> None: """ Args: @@ -39,6 +45,10 @@ def __init__( num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. """ @@ -62,11 +72,30 @@ def __init__( self.scale = self.head_dim**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C - def forward(self, x): + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] - att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + att_mat = att_mat.softmax(dim=-1) + if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 4fc2c16f73..29b72a4f37 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -70,7 +70,7 @@ def use_factory(fact_args): from monai.networks.utils import has_nvfuser_instance_norm from monai.utils import ComponentStore, look_up_option, optional_import -__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] +__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"] class LayerFactory(ComponentStore): @@ -201,6 +201,10 @@ def split_args(args): Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") +RelPosEmbedding = LayerFactory( + name="Relative positional embedding layers", + description="Factory for creating relative positional embedding factory", +) @Dropout.factory_function("dropout") @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] + + +@RelPosEmbedding.factory_function("decomposed") +def decomposed_rel_pos_embedding() -> type[nn.Module]: + from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding + + return DecomposedRelativePosEmbedding diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py index ace1af27b6..8676f74638 100644 --- a/monai/networks/layers/utils.py +++ b/monai/networks/layers/utils.py @@ -11,9 +11,11 @@ from __future__ import annotations +from typing import Optional + import torch.nn -from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args +from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args from monai.utils import has_option __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"] @@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1): pool_name, pool_args = split_args(name) pool_type = Pool[pool_name, spatial_dims] return pool_type(**pool_args) + + +def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int): + embedding_name, embedding_args = split_args(name) + embedding_type = RelPosEmbedding[embedding_name] + # create a dictionary with the default values which can be overridden by embedding_args + kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args} + # filter out unused argument names + kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)} + + return embedding_type(**kw_args) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 6062b5352f..0d0553ed2c 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -20,6 +20,7 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock +from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import einops, has_einops = optional_import("einops") @@ -28,12 +29,20 @@ for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 6, 8, 12]: - test_case = [ - {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate}, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): From f87339e35ce32dfce5ba69c698c2ff85102735c2 Mon Sep 17 00:00:00 2001 From: Yufan He <59374597+heyufan1995@users.noreply.github.com> Date: Thu, 7 Dec 2023 21:36:21 -0500 Subject: [PATCH 45/67] Fix swinunetrv2 2D bug (#7302) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: heyufan1995 Signed-off-by: Mark Graham --- monai/networks/nets/swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 10c4ce3d8e..6f96dfd291 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1024,7 +1024,7 @@ def __init__( self.layers4.append(layer) if self.use_v2: layerc = UnetrBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=embed_dim * 2**i_layer, out_channels=embed_dim * 2**i_layer, kernel_size=3, From 9c07cb8d1f01feba4115309d3d6e5ee386372ab1 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:07:04 +0800 Subject: [PATCH 46/67] Fix `RuntimeError` in `DataAnalyzer` (#7310) Fixes #7309 ### Description `DataAnalyzer` only catch error when data is on GPU, add catching error when data is on CPU. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/apps/auto3dseg/data_analyzer.py | 26 ++++++++++++++++---------- monai/auto3dseg/analyzer.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 9280fb5be5..15e56abfea 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -28,7 +28,7 @@ from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd -from monai.utils import StrEnum, min_version, optional_import +from monai.utils import ImageMetaKey, StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys @@ -343,19 +343,25 @@ def _get_all_case_stats( d = summarizer(batch_data) except BaseException as err: if "image_meta_dict" in batch_data.keys(): - filename = batch_data["image_meta_dict"]["filename_or_obj"] + filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ] else: - filename = batch_data[self.image_key].meta["filename_or_obj"] + filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ] logger.info(f"Unable to process data {filename} on {device}. {err}") if self.device.type == "cuda": logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.") - batch_data[self.image_key] = batch_data[self.image_key].to("cpu") - if self.label_key is not None: - label = batch_data[self.label_key] - if not _label_argmax: - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to("cpu") - d = summarizer(batch_data) + try: + batch_data[self.image_key] = batch_data[self.image_key].to("cpu") + if self.label_key is not None: + label = batch_data[self.label_key] + if not _label_argmax: + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to("cpu") + d = summarizer(batch_data) + except BaseException as err: + logger.info(f"Unable to process data {filename} on {device}. {err}") + continue + else: + continue stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 654999d439..d5cfb21dab 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe torch.set_grad_enabled(False) ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key] # (H,W,D) + ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") From 7eb1a7d8ff3862874ac1a3d61fcf449f0904b590 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 14 Dec 2023 10:14:19 +0800 Subject: [PATCH 47/67] Support specified filenames in `Saveimage` (#7318) Fixes #7317 ### Description Add support specified filename for users to save like nibabel.save. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/transforms/io/array.py | 17 ++++++++++++++--- tests/test_save_image.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd7e4ef090..7222a26fc3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -414,6 +414,9 @@ def __init__( self.fname_formatter = output_name_formatter self.output_ext = output_ext.lower() or output_format.lower() + self.output_ext = ( + f".{self.output_ext}" if self.output_ext and not self.output_ext.startswith(".") else self.output_ext + ) if isinstance(writer, str): writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in if not has_built_in: @@ -458,15 +461,23 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ self.write_kwargs.update(write_kwargs) return self - def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None): + def __call__( + self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None + ): """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of metadata corresponding to the data. + filename: str or file-like object which to save img. + If specified, will ignore `self.output_name_formatter` and `self.folder_layout`. """ meta_data = img.meta if isinstance(img, MetaTensor) else meta_data - kw = self.fname_formatter(meta_data, self) - filename = self.folder_layout.filename(**kw) + if filename is not None: + filename = f"{filename}{self.output_ext}" + else: + kw = self.fname_formatter(meta_data, self) + filename = self.folder_layout.filename(**kw) + if meta_data: meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ())) if len(meta_spatial_shape) >= len(img.shape): diff --git a/tests/test_save_image.py b/tests/test_save_image.py index ba94ab5087..d88db201ce 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -37,6 +37,8 @@ False, ] +TEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), ".dcm", False] + @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImage(unittest.TestCase): @@ -58,6 +60,20 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample): filepath = "testfile0" if meta_data is not None else "0" self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext))) + @parameterized.expand([TEST_CASE_5]) + def test_saved_content_with_filename(self, test_data, output_ext, resample): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + output_ext=output_ext, + resample=resample, + separate_folder=False, # test saving into the same folder + ) + filename = str(os.path.join(tempdir, "test")) + trans(test_data, filename=filename) + + self.assertTrue(os.path.exists(filename + output_ext)) + if __name__ == "__main__": unittest.main() From 426cec952480430e61616b5080b9378750a1dbcf Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 15 Dec 2023 11:32:18 +0800 Subject: [PATCH 48/67] Fix typo (#7321) Fix typo. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/metrics/hausdorff_distance.py | 2 +- monai/metrics/surface_dice.py | 2 +- monai/metrics/surface_distance.py | 2 +- monai/metrics/utils.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index d9bbf17db3..d727eb0567 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -190,7 +190,7 @@ def compute_hausdorff_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=not directed, + symmetric=not directed, class_index=c, ) percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 635eb1bc24..b20b47a1a5 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -253,7 +253,7 @@ def compute_surface_dice( distance_metric=distance_metric, spacing=spacing_list[b], use_subvoxels=use_subvoxels, - symetric=True, + symmetric=True, class_index=c, ) boundary_correct: int | torch.Tensor | float diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7ce632c588..3cb336d6a0 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -177,7 +177,7 @@ def compute_average_surface_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=symmetric, + symmetric=symmetric, class_index=c, ) surface_distance = torch.cat(distances) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 62e6520b96..d4b8f6e9b6 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -295,7 +295,7 @@ def get_edge_surface_distance( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, use_subvoxels: bool = False, - symetric: bool = False, + symmetric: bool = False, class_index: int = -1, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], @@ -314,7 +314,7 @@ def get_edge_surface_distance( See :py:func:`monai.metrics.utils.get_surface_distance`. use_subvoxels: whether to use subvoxel resolution (using the spacing). This will return the areas of the edges. - symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. + symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. Returns: @@ -338,7 +338,7 @@ def get_edge_surface_distance( " this may result in nan/inf distance." ) distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] - if symetric: + if symmetric: distances = ( get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), From 4e2160c45f75c3af6b301276e80590a98708f970 Mon Sep 17 00:00:00 2001 From: binliunls <107988372+binliunls@users.noreply.github.com> Date: Fri, 15 Dec 2023 22:00:24 +0800 Subject: [PATCH 49/67] fix optimizer pararmeter issue (#7322) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: binliu Signed-off-by: Mark Graham --- monai/handlers/mlflow_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index a2bd345dc6..df209c1c8b 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -401,7 +401,7 @@ def _default_iteration_log(self, engine: Engine) -> None: cur_optimizer = engine.optimizer for param_name in self.optimizer_param_names: params = { - f"{param_name} group_{i}": float(param_group[param_name]) + f"{param_name}_group_{i}": float(param_group[param_name]) for i, param_group in enumerate(cur_optimizer.param_groups) } self._log_metrics(params, step=engine.state.iteration) From 0e89c92b399e20cff43d2f4c0bf7609299c2ba47 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 18 Dec 2023 12:00:43 +0800 Subject: [PATCH 50/67] Fix `lazy` ignored in `SpatialPadd` (#7316) Fixes #7314 #7315. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ben Murray Signed-off-by: Mark Graham --- monai/transforms/croppad/dictionary.py | 9 +++------ tests/padders.py | 3 +++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 56d214c51d..be9441dc4a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -221,9 +221,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class BorderPadd(Padd): @@ -274,9 +273,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class DivisiblePadd(Padd): @@ -324,9 +322,8 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - LazyTransform.__init__(self, lazy) padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class Cropd(MapTransform, InvertibleTransform, LazyTransform): diff --git a/tests/padders.py b/tests/padders.py index 02d7b40af6..ae1153bdfd 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -136,6 +136,9 @@ def pad_test_pending_ops(self, input_param, input_shape): # TODO: mode="bilinear" may report error overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} result = apply_pending(pending_result, overrides=overrides)[0] + # lazy in constructor + pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param) + self.assertTrue(pad_fn_lazy.lazy) # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): From 0a03228ed46cc7b806a591b70f08f967a69af9b3 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 28 Dec 2023 22:22:52 +0800 Subject: [PATCH 51/67] Update openslide-python version (#7344) Signed-off-by: Mark Graham --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index cacbefe234..2639c0a3e7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,7 +27,7 @@ ninja torchvision psutil cucim>=23.2.0; platform_system == "Linux" -openslide-python==1.1.2 +openslide-python imagecodecs; platform_system == "Linux" or platform_system == "Darwin" tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas From 775ecac72784f3cc99701c6106bedbca7ee98320 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 29 Dec 2023 12:33:46 +0800 Subject: [PATCH 52/67] Upgrade the version of `transformers` (#7343) Fixes #7338 ### Description transformers' version is pinned to v4.22 since https://github.com/Project-MONAI/MONAI/issues/5157. Updated the version refer to https://github.com/huggingface/transformers/issues/21678. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/networks/nets/transchex.py | 49 +++++++++----------------------- requirements-dev.txt | 2 +- tests/test_transchex.py | 3 +- 3 files changed, 15 insertions(+), 39 deletions(-) diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index ff27903cef..6bfff3c956 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -12,20 +12,17 @@ from __future__ import annotations import math -import os -import shutil -import tarfile -import tempfile from collections.abc import Sequence import torch from torch import nn +from monai.config.type_definitions import PathLike from monai.utils import optional_import transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0] -cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +cached_file = optional_import("transformers.utils", name="cached_file")[0] BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] @@ -63,44 +60,16 @@ def from_pretrained( state_dict=None, cache_dir=None, from_tf=False, + path_or_repo_id="bert-base-uncased", + filename="pytorch_model.bin", *inputs, **kwargs, ): - archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - tempdir = tempfile.mkdtemp() - with tarfile.open(resolved_archive_file, "r:gz") as archive: - - def is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(archive, tempdir) - serialization_dir = tempdir + weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir) model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, "pytorch_model.bin") state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) - if tempdir: - shutil.rmtree(tempdir) if from_tf: - weights_path = os.path.join(serialization_dir, "model.ckpt") return load_tf_weights_in_bert(model, weights_path) old_keys = [] new_keys = [] @@ -304,6 +273,8 @@ def __init__( chunk_size_feed_forward: int = 0, is_decoder: bool = False, add_cross_attention: bool = False, + path_or_repo_id: str | PathLike = "bert-base-uncased", + filename: str = "pytorch_model.bin", ) -> None: """ Args: @@ -315,6 +286,10 @@ def __init__( num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. drop_out: fraction of the input units to drop. + path_or_repo_id: This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filename: The name of the file to locate in `path_or_repo`. The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. @@ -369,6 +344,8 @@ def __init__( num_vision_layers=num_vision_layers, num_mixed_layers=num_mixed_layers, bert_config=bert_config, + path_or_repo_id=path_or_repo_id, + filename=filename, ) self.patch_size = patch_size diff --git a/requirements-dev.txt b/requirements-dev.txt index 2639c0a3e7..4685cd1572 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers>=4.36.0 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/tests/test_transchex.py b/tests/test_transchex.py index 8fb1f56715..9ad847cdaa 100644 --- a/tests/test_transchex.py +++ b/tests/test_transchex.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.transchex import Transchex -from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick +from tests.utils import skip_if_quick TEST_CASE_TRANSCHEX = [] for drop_out in [0.4]: @@ -46,7 +46,6 @@ @skip_if_quick -@SkipIfAtLeastPyTorchVersion((1, 10)) class TestTranschex(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSCHEX) def test_shape(self, input_param, expected_shape): From bc3c0621888a90fa3416e154fd461e3a5b8e8456 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 03:34:50 +0000 Subject: [PATCH 53/67] Bump github/codeql-action from 2 to 3 (#7354) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 2 to 3.
Release notes

Sourced from github/codeql-action's releases.

CodeQL Bundle v2.15.5

Bundles CodeQL CLI v2.15.5

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.5:

CodeQL Bundle v2.15.4

Bundles CodeQL CLI v2.15.4

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.4:

CodeQL Bundle

Bundles CodeQL CLI v2.15.3

Includes the following CodeQL language packs from github/codeql@codeql-cli/v2.15.3:

... (truncated)

Changelog

Sourced from github/codeql-action's changelog.

Commits
  • e0c2b0a change version numbers inside processing function as well
  • 8e4a6c7 improve handling of changelog processing for backports
  • 511f073 Merge pull request #2033 from github/dependabot/npm_and_yarn/npm-0a98872b3d
  • ebf5a83 Merge pull request #2035 from github/mergeback/v3.22.11-to-main-b374143c
  • 7813bda Update checked-in dependencies
  • 2b2fb6b Update changelog and version after v3.22.11
  • b374143 Merge pull request #2034 from github/update-v3.22.11-64e61baea
  • 95591ba Merge branch 'main' into dependabot/npm_and_yarn/npm-0a98872b3d
  • e2b5cc7 Update changelog for v3.22.11
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=2&new-version=3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/codeql-analysis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 3d32ae407a..18f1519b5a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -72,4 +72,4 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 From 9d78d9958b168137498afe2692f24b7a86449999 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 06:42:41 +0000 Subject: [PATCH 54/67] Bump actions/upload-artifact from 3 to 4 (#7350) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4.
Release notes

Sourced from actions/upload-artifact's releases.

v4.0.0

What's Changed

The release of upload-artifact@v4 and download-artifact@v4 are major changes to the backend architecture of Artifacts. They have numerous performance and behavioral improvements.

For more information, see the @​actions/artifact documentation.

New Contributors

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v4.0.0

v3.1.3

What's Changed

Full Changelog: https://github.com/actions/upload-artifact/compare/v3...v3.1.3

v3.1.2

  • Update all @actions/* NPM packages to their latest versions- #374
  • Update all dev dependencies to their most recent versions - #375

v3.1.1

  • Update actions/core package to latest version to remove set-output deprecation warning #351

v3.1.0

What's Changed

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/upload-artifact&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/docker.yml | 2 +- .github/workflows/release.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f51e4fdf76..f80a4c2c96 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -36,7 +36,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7197215486..e9817e1c4c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -66,7 +66,7 @@ jobs: - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist path: dist/ @@ -108,7 +108,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py From 535830e756f7568cbf3adab52cf65608c7732a22 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 17:29:09 +0800 Subject: [PATCH 55/67] Bump actions/setup-python from 4 to 5 (#7351) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5.
Release notes

Sourced from actions/setup-python's releases.

v5.0.0

What's Changed

In scope of this release, we update node version runtime from node16 to node20 (actions/setup-python#772). Besides, we update dependencies to the latest versions.

Full Changelog: https://github.com/actions/setup-python/compare/v4.8.0...v5.0.0

v4.8.0

What's Changed

In scope of this release we added support for GraalPy (actions/setup-python#694). You can use this snippet to set up GraalPy:

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
  with:
    python-version: 'graalpy-22.3'
- run: python my_script.py

Besides, the release contains such changes as:

New Contributors

Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.8.0

v4.7.1

What's Changed

Full Changelog: https://github.com/actions/setup-python/compare/v4...v4.7.1

v4.7.0

In scope of this release, the support for reading python version from pyproject.toml was added (actions/setup-python#669).

      - name: Setup Python
        uses: actions/setup-python@v4
</tr></table>

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/setup-python&package-manager=github_actions&previous-version=4&new-version=5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/cron-ngc-bundle.yml | 2 +- .github/workflows/docker.yml | 2 +- .github/workflows/pythonapp-min.yml | 6 +++--- .github/workflows/pythonapp.yml | 8 ++++---- .github/workflows/release.yml | 4 ++-- .github/workflows/setupapp.yml | 4 ++-- .github/workflows/weekly-preview.yml | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml index 0bba630d03..84666204a9 100644 --- a/.github/workflows/cron-ngc-bundle.yml +++ b/.github/workflows/cron-ngc-bundle.yml @@ -19,7 +19,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f80a4c2c96..c375e82e74 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -26,7 +26,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 558c270e33..7b7930bdf5 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -30,7 +30,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -76,7 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Prepare pip wheel @@ -121,7 +121,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index ad8b555dd4..29a79759e0 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -28,7 +28,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -69,7 +69,7 @@ jobs: disk-root: "D:" - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -128,7 +128,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -209,7 +209,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9817e1c4c..9334908bfc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install setuptools @@ -97,7 +97,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 0ff7162bee..82394a86dd 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -83,7 +83,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: cache weekly timestamp @@ -120,7 +120,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index c631982745..e94e1dac5a 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -14,7 +14,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - name: Install setuptools From 488c335af952648de829a764f55544a11b04e630 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jan 2024 23:42:25 +0800 Subject: [PATCH 56/67] Bump actions/download-artifact from 3 to 4 (#7352) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4.
Release notes

Sourced from actions/download-artifact's releases.

v4.0.0

What's Changed

The release of upload-artifact@v4 and download-artifact@v4 are major changes to the backend architecture of Artifacts. They have numerous performance and behavioral improvements.

For more information, see the @​actions/artifact documentation.

New Contributors

Full Changelog: https://github.com/actions/download-artifact/compare/v3...v4.0.0

v3.0.2

  • Bump @actions/artifact to v1.1.1 - actions/download-artifact#195
  • Fixed a bug in Node16 where if an HTTP download finished too quickly (<1ms, e.g. when it's mocked) we attempt to delete a temp file that has not been created yet actions/toolkit#1278

v3.0.1

Commits
  • f44cd7b Merge pull request #259 from actions/robherley/glob-downloads
  • 3181fe8 add some migration docs
  • aaaac7b licensed cache
  • 7c9182f update readme
  • b94e701 licensed cache
  • 0b55470 add test case for globbed downloads to same directory
  • 0b51c2e update prettier/eslint versions
  • c4c6db7 support globbing artifact list & merging download directory
  • 1bd0606 Merge pull request #252 from stchr/patch-1
  • eff4d42 fix default for run-id
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/download-artifact&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/docker.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c375e82e74..229ae675f5 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -56,7 +56,7 @@ jobs: with: ref: dev - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: docker_build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9334908bfc..a03d2cea6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -125,7 +125,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: Set tag From 41abb20e7a9c5b54b353a0f20a7544ab70704ff4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:41:55 +0800 Subject: [PATCH 57/67] Bump peter-evans/slash-command-dispatch from 3.0.1 to 3.0.2 (#7353) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [peter-evans/slash-command-dispatch](https://github.com/peter-evans/slash-command-dispatch) from 3.0.1 to 3.0.2.
Release notes

Sourced from peter-evans/slash-command-dispatch's releases.

Slash Command Dispatch v3.0.2

What's Changed

New Contributors

Full Changelog: https://github.com/peter-evans/slash-command-dispatch/compare/v3.0.1...v3.0.2

Commits
  • f996d7b Fix the CollaboratorPermission GraphQL query (#301)
  • 05b97d6 build(deps-dev): bump @​types/node from 16.18.65 to 16.18.67 (#300)
  • 8e70073 build(deps-dev): bump eslint from 8.54.0 to 8.55.0 (#299)
  • bd00135 build(deps-dev): bump @​types/node from 16.18.62 to 16.18.65 (#298)
  • ee873b6 build(deps-dev): bump eslint from 8.53.0 to 8.54.0 (#296)
  • 44abc47 build(deps-dev): bump @​types/node from 16.18.61 to 16.18.62 (#295)
  • 19ad7b8 build(deps-dev): bump @​types/node from 16.18.60 to 16.18.61 (#294)
  • 29a9815 build(deps-dev): bump prettier from 3.0.3 to 3.1.0 (#293)
  • ade0309 build(deps-dev): bump eslint from 8.52.0 to 8.53.0 (#292)
  • fc8222e build(deps-dev): bump @​types/node from 16.18.59 to 16.18.60 (#291)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=peter-evans/slash-command-dispatch&package-manager=github_actions&previous-version=3.0.1&new-version=3.0.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- .github/workflows/chatops.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index b4e201a0d9..59c7d070b4 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v3.0.1 + uses: peter-evans/slash-command-dispatch@v3.0.2 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} From 5ce1b60bdb97b6e80261f6bc2d78963ad5460b0c Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:18:45 +0800 Subject: [PATCH 58/67] Give more useful exception when batch is considered during matrix multiplication (#7326) Fixes #7323 ### Description Give more useful exception when batch is considered during matrix multiplication. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/transforms/inverse.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 41fabb35aa..f94f11eca9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -185,7 +185,17 @@ def track_transform_meta( # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] - affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + try: + affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + except RuntimeError as e: + if orig_affine.ndim > 2: + if data_t.is_batch: + msg = "Transform applied to batched tensor, should be applied to instances only" + else: + msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation." + raise RuntimeError(msg) from e + else: + raise out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64) if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): From 64577f95f3a938a5eb74415335c78573db3ea193 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:13:14 +0800 Subject: [PATCH 59/67] Fix incorrectly size compute in auto3dseg analyzer (#7374) Fixes #7222 ### Description remove int convert here. https://github.com/Project-MONAI/MONAI/blob/8fa6931b14ba9617a595fff1d396ac44cc82e207/monai/auto3dseg/analyzer.py#L259 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/auto3dseg/analyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index d5cfb21dab..56419da4cb 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -256,7 +256,7 @@ def __call__(self, data): ) report[ImageStatsKeys.SIZEMM] = [ - int(a * b) for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) + a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) ] report[ImageStatsKeys.INTENSITY] = [ From ea0e89be9fda0a073229ea36ca769d441a528882 Mon Sep 17 00:00:00 2001 From: Kaibo Tang Date: Tue, 9 Jan 2024 22:41:30 -0500 Subject: [PATCH 60/67] 7380 mention demo in bending energy and diffusion docstrings (#7381) Fixes #7380. ### Description Mention [demo](https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb) in bending energy and diffusion docstrings. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: kaibo Signed-off-by: Mark Graham --- monai/losses/deform.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 129abeedd2..37e4468d4b 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -46,7 +46,10 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of pred using central finite difference. + Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference. + + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -75,6 +78,9 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: @@ -84,7 +90,8 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( - f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" ) # first order gradient @@ -120,12 +127,15 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: class DiffusionLoss(_Loss): """ - Calculate the diffusion based on first-order differentiation of pred using central finite difference. + Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference. For the original paper, please refer to VoxelMorph: A Learning Framework for Deformable Medical Image Registration, Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. + Adapted from: VoxelMorph (https://github.com/voxelmorph/voxelmorph) """ From fff5cddfe9e4867c2a02a4f7ba5e94cd750870d9 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 12 Jan 2024 20:25:42 +0800 Subject: [PATCH 61/67] Pin gdown version to v4.6.3 (#7384) Workaround for #7382 #7383 ### Description Based on the comment [here](https://github.com/wkentaro/gdown/issues/291#issuecomment-1887060708), pin the gdown version as a workaround. Will review this one once gdown has some update internal. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/apps/utils.py | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index d2dd63b958..442dbabba0 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -30,7 +30,7 @@ from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import -gdown, has_gdown = optional_import("gdown", "4.4") +gdown, has_gdown = optional_import("gdown", "4.6.3") if TYPE_CHECKING: from tqdm import tqdm diff --git a/requirements-dev.txt b/requirements-dev.txt index 4685cd1572..f8bc9d5a3e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Full requirements for developments -r requirements-min.txt pytorch-ignite==0.4.11 -gdown>=4.4.0 +gdown>=4.4.0, <=4.6.3 scipy>=1.7.1 itk>=5.2 nibabel From 6556f88c887f0919f033b07c581a079b51c65cea Mon Sep 17 00:00:00 2001 From: "axel.vlaminck" Date: Thu, 18 Jan 2024 18:07:22 +0100 Subject: [PATCH 62/67] Track applied operations in image filter (#7395) Fixes #7394 ### Description When ImageFilter is in the transformation sequence it didn't pass the applied_operations. Now it is passed when present. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: axel.vlaminck Signed-off-by: Mark Graham --- monai/transforms/utility/array.py | 11 ++++++++--- tests/test_image_filter.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2322f2123f..5dfbcb0e91 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + def __call__( + self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None + ) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + applied_operations: An optional list of operations that have been applied to the data Returns: A MetaTensor with the same shape as `img` and identical metadata """ if isinstance(img, MetaTensor): meta_dict = img.meta + applied_operations = img.applied_operations + img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict: - img_ = MetaTensor(img_, meta=meta_dict) + if meta_dict is not None or applied_operations is not None: + img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 841a5d5cd5..985ea95e79 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd @@ -115,6 +116,21 @@ def test_call_3d(self, filter_name): out_tensor = filter(SAMPLE_IMAGE_3D) self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + def test_pass_applied_operations(self): + "Test that applied operations are passed through" + applied_operations = ["op1", "op2"] + image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertEqual(out_tensor.applied_operations, applied_operations) + + def test_pass_empty_metadata_dict(self): + "Test that applied operations are passed through" + image = MetaTensor(SAMPLE_IMAGE_2D, meta={}) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertTrue(isinstance(out_tensor, MetaTensor)) + class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS) From ef912c84160477aa80b8a41052325663ae070829 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:00:13 +0800 Subject: [PATCH 63/67] Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator` (#7375) Fixes # . ### Description Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator`. Convert to `torch.Tensor` internally. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/engines/evaluator.py | 51 +++++++++++++++++++++++++++++++++++-- monai/engines/trainer.py | 52 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 119853d5c5..2c8dfe6b85 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,12 +11,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader from monai.config import IgniteInfo, KeysCollection +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,7 +27,7 @@ from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ @@ -238,6 +244,8 @@ def __init__( decollate: bool = True, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -259,8 +267,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: @@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index dec5450725..c1364fe015 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -18,6 +19,7 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,6 +27,7 @@ from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys +from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ def __init__( @@ -153,6 +159,8 @@ def __init__( optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -174,8 +182,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer @@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -231,6 +266,19 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output From 21f9ae12ceb9980b6d40ec5c9c70581eceb059fd Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 22 Jan 2024 23:58:29 +0800 Subject: [PATCH 64/67] Fix CUDA_VISIBLE_DEVICES setting ignored (#7408) Fixes #7407 ### Description Move `optional import cucim` inside the function to avoid using all GPUs. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/metrics/utils.py | 6 ++---- tests/test_set_visible_devices.py | 7 +++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index d4b8f6e9b6..e7057256fb 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -38,10 +38,6 @@ binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") -cucim_distance_transform_edt, has_cucim_distance_transform_edt = optional_import( - "cucim.core.operations.morphology", name="distance_transform_edt" -) __all__ = [ "ignore_background", @@ -179,6 +175,8 @@ def get_mask_edges( always_return_as_numpy: whether to a numpy array regardless of the input type. If False, return the same type as inputs. """ + # move in the funciton to avoid using all the GPUs + cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") converter: Any diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index 53703e107a..993e8a4ac2 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -35,6 +35,13 @@ def test_visible_devices(self): ) self.assertEqual(num_gpus_before, num_gpus_after) + # test import monai won't affect setting CUDA_VISIBLE_DEVICES + num_gpus_after_monai = self.run_process_and_get_exit_code( + 'python -c "import os; import torch; import monai; ' + + "os.environ['CUDA_VISIBLE_DEVICES'] = '0'; exit(torch.cuda.device_count())\"" + ) + self.assertEqual(num_gpus_after_monai, 1) + if __name__ == "__main__": unittest.main() From 56212b9955f991227c2f6280b436140279043988 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 25 Jan 2024 09:47:36 +0800 Subject: [PATCH 65/67] Fix Incorrect updated affine in `NrrdReader` and update docstring in `ITKReader` (#7415) Fixes #7414 Fixes #7371 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/data/image_reader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 0823d11834..2361bb63a7 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -168,8 +168,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing follows the numpy convention; - otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is ``False``. @@ -1323,7 +1323,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1344,7 +1344,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, header: dict) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -1353,8 +1353,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = img.header["space directions"] - origin = img.header["space origin"] + direction = header["space directions"] + origin = header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 From 4a9d10bab9808acb0b625535ec2d3f0b4a8aa8c6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 30 Jan 2024 09:42:03 +0800 Subject: [PATCH 66/67] Ignore E704 after update black (#7422) Fixes #7421 ### Description https://pypi.org/project/black/24.1.1/ ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mark Graham --- monai/apps/detection/utils/anchor_utils.py | 8 ++++++-- monai/data/decathlon_datalist.py | 6 ++---- monai/losses/image_dissimilarity.py | 4 +--- monai/transforms/utility/dictionary.py | 6 +++--- monai/utils/dist.py | 9 +++------ monai/utils/misc.py | 6 ++---- setup.cfg | 2 ++ tests/test_hilbert_transform.py | 20 +++++++++++--------- tests/test_spacing.py | 8 +++++--- 9 files changed, 35 insertions(+), 34 deletions(-) diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index baaa7ce874..283169b653 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -369,8 +369,12 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator): def __init__( self, feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8), - base_anchor_shapes: Sequence[Sequence[int]] - | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)), + base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = ( + (32, 32, 32), + (48, 20, 20), + (20, 48, 20), + (20, 20, 48), + ), indexing: str = "ij", ) -> None: nn.Module.__init__(self) diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 39219e059a..dd132770ec 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -277,9 +277,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torc if order == 0: weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = ( - weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 - ) + weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: raise ValueError(f"Do not support b-spline {order}-order parzen windowing") diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index ec10bd8537..1cd9ff6323 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1765,9 +1765,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd -ConvertToMultiChannelBasedOnBratsClassesD = ( - ConvertToMultiChannelBasedOnBratsClassesDict -) = ConvertToMultiChannelBasedOnBratsClassesd +ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = ( + ConvertToMultiChannelBasedOnBratsClassesd +) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4f2501a7ee..81f582daef 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: diff --git a/setup.cfg b/setup.cfg index 0069214de3..4180ced917 100644 --- a/setup.cfg +++ b/setup.cfg @@ -174,6 +174,7 @@ max_line_length = 120 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 # B036 https://github.com/Project-MONAI/MONAI/issues/7396 +# E704 https://github.com/Project-MONAI/MONAI/issues/7421 ignore = E203 E501 @@ -188,6 +189,7 @@ ignore = B907 B908 B036 + E704 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 4c49aecd8b..68fa0b1192 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -180,15 +180,17 @@ def test_value(self, arguments, image, expected_data, atol): @SkipIfNoModule("torch.fft") class TestHilbertTransformGPU(unittest.TestCase): @parameterized.expand( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ], + ( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ] + ), skip_on_empty=True, ) def test_value(self, arguments, image, expected_data, atol): diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 1ff1518297..8b664641d7 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -74,9 +74,11 @@ torch.ones((1, 2, 1, 2)), # data torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), {}, - torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) - if USE_COMPILED - else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ( + torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) + if USE_COMPILED + else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) + ), *device, ] ) From 6faf4b202598c4de1f902eae1430b1434f11b363 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 30 Jan 2024 11:34:07 +0000 Subject: [PATCH 67/67] new formatting Signed-off-by: Mark Graham --- monai/networks/nets/vqvae.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py index d4771e203a..f198bfbb2b 100644 --- a/monai/networks/nets/vqvae.py +++ b/monai/networks/nets/vqvae.py @@ -312,10 +312,16 @@ def __init__( channels: Sequence[int] = (96, 96, 192), num_res_layers: int = 3, num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: Sequence[Tuple[int, int, int, int]] - | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Sequence[Tuple[int, int, int, int, int]] - | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = ( + (2, 4, 1, 1), + (2, 4, 1, 1), + (2, 4, 1, 1), + ), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = ( + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + ), num_embeddings: int = 32, embedding_dim: int = 64, embedding_init: str = "normal",