diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index b4e201a0d9..59c7d070b4 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: dispatch - uses: peter-evans/slash-command-dispatch@v3.0.1 + uses: peter-evans/slash-command-dispatch@v3.0.2 with: token: ${{ secrets.PR_MAINTAIN }} reaction-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 3d32ae407a..18f1519b5a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -42,7 +42,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -72,4 +72,4 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml index 0bba630d03..84666204a9 100644 --- a/.github/workflows/cron-ngc-bundle.yml +++ b/.github/workflows/cron-ngc-bundle.yml @@ -19,7 +19,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f51e4fdf76..229ae675f5 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -26,7 +26,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash @@ -36,7 +36,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py @@ -56,7 +56,7 @@ jobs: with: ref: dev - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: docker_build diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 558c270e33..7b7930bdf5 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -30,7 +30,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -76,7 +76,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Prepare pip wheel @@ -121,7 +121,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index ad8b555dd4..29a79759e0 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -28,7 +28,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: 
cache weekly timestamp @@ -69,7 +69,7 @@ jobs: disk-root: "D:" - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: Prepare pip wheel @@ -128,7 +128,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp @@ -209,7 +209,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7197215486..a03d2cea6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,7 +19,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install setuptools @@ -66,7 +66,7 @@ jobs: - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: dist path: dist/ @@ -97,7 +97,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - shell: bash @@ -108,7 +108,7 @@ jobs: python setup.py build cat build/lib/monai/_version.py - name: Upload version - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: _version.py path: build/lib/monai/_version.py @@ -125,7 +125,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download version - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: _version.py - name: Set tag diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 0ff7162bee..82394a86dd 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -83,7 +83,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: cache weekly timestamp @@ -120,7 +120,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Set up Python 3.8 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.8' - name: cache weekly timestamp diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index c631982745..e94e1dac5a 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -14,7 +14,7 @@ jobs: ref: dev fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.9' - name: Install setuptools diff --git a/docs/requirements.txt b/docs/requirements.txt index a9bbc384f8..e5bedf8552 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 tensorboardX diff --git a/docs/source/engines.rst b/docs/source/engines.rst index afb2682822..a015c7b2a3 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst 
@@ -30,6 +30,11 @@ Workflows .. autoclass:: GanTrainer :members: +`AdversarialTrainer` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: AdversarialTrainer + :members: + `Evaluator` ~~~~~~~~~~~ .. autoclass:: Evaluator diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 568c7dfc77..e929e9d605 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -96,6 +96,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py index 9280fb5be5..15e56abfea 100644 --- a/monai/apps/auto3dseg/data_analyzer.py +++ b/monai/apps/auto3dseg/data_analyzer.py @@ -28,7 +28,7 @@ from monai.data import DataLoader, Dataset, partition_dataset from monai.data.utils import no_collation from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd -from monai.utils import StrEnum, min_version, optional_import +from monai.utils import ImageMetaKey, StrEnum, min_version, optional_import from monai.utils.enums import DataStatsKeys, ImageStatsKeys @@ -343,19 +343,25 @@ def _get_all_case_stats( d = summarizer(batch_data) except BaseException as err: if "image_meta_dict" in batch_data.keys(): - filename = batch_data["image_meta_dict"]["filename_or_obj"] + filename = batch_data["image_meta_dict"][ImageMetaKey.FILENAME_OR_OBJ] else: - filename = batch_data[self.image_key].meta["filename_or_obj"] + filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ] logger.info(f"Unable to process data {filename} on {device}. {err}") if self.device.type == "cuda": logger.info("DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.") - batch_data[self.image_key] = batch_data[self.image_key].to("cpu") - if self.label_key is not None: - label = batch_data[self.label_key] - if not _label_argmax: - label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] - batch_data[self.label_key] = label.to("cpu") - d = summarizer(batch_data) + try: + batch_data[self.image_key] = batch_data[self.image_key].to("cpu") + if self.label_key is not None: + label = batch_data[self.label_key] + if not _label_argmax: + label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0] + batch_data[self.label_key] = label.to("cpu") + d = summarizer(batch_data) + except BaseException as err: + logger.info(f"Unable to process data {filename} on {device}. 
{err}") + continue + else: + continue stats_by_cases = { DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH], diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index baaa7ce874..283169b653 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -369,8 +369,12 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator): def __init__( self, feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8), - base_anchor_shapes: Sequence[Sequence[int]] - | Sequence[Sequence[float]] = ((32, 32, 32), (48, 20, 20), (20, 48, 20), (20, 20, 48)), + base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = ( + (32, 32, 32), + (48, 20, 20), + (20, 48, 20), + (20, 20, 48), + ), indexing: str = "ij", ) -> None: nn.Module.__init__(self) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index d2dd63b958..442dbabba0 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -30,7 +30,7 @@ from monai.config.type_definitions import PathLike from monai.utils import look_up_option, min_version, optional_import -gdown, has_gdown = optional_import("gdown", "4.4") +gdown, has_gdown = optional_import("gdown", "4.6.3") if TYPE_CHECKING: from tqdm import tqdm diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 654999d439..56419da4cb 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -256,7 +256,7 @@ def __call__(self, data): ) report[ImageStatsKeys.SIZEMM] = [ - int(a * b) for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) + a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING]) ] report[ImageStatsKeys.INTENSITY] = [ @@ -460,7 +460,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe torch.set_grad_enabled(False) ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key] # (H,W,D) + ndas_label: MetaTensor = d[self.label_key].astype(torch.int8) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 6f163f972e..14765dcfaa 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -24,13 +24,11 @@ @overload -def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: - ... +def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ... @overload -def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: - ... +def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ... def _compute_path(base_dir, element, check_path=False): diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 0823d11834..2361bb63a7 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -168,8 +168,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing follows the numpy convention; - otherwise, the spatial indexing convention is reversed to be compatible with ITK. 
Default is ``False``. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is ``False``. @@ -1323,7 +1323,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1344,7 +1344,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, header: dict) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -1353,8 +1353,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: """ - img: A `NrrdImage` loaded from image file + header: the NRRD image header, as a dict """ - direction = img.header["space directions"] - origin = img.header["space origin"] + direction = header["space directions"] + origin = header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d8dc51f620..93cc40e292 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,12 +12,14 @@ from __future__ import annotations from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator -from .trainer import GanTrainer, SupervisedTrainer, Trainer +from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer from .utils import ( + DiffusionPrepareBatch, IterationEvents, PrepareBatch, PrepareBatchDefault, PrepareBatchExtraInput, + VPredictionPrepareBatch, default_make_latent, default_metric_cmp_fn, default_prepare_batch, diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 119853d5c5..2c8dfe6b85 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,12 +11,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader from monai.config import IgniteInfo, KeysCollection +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -25,7 +27,7 @@ from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -213,6 +215,10 @@ class SupervisedEvaluator(Evaluator): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + compile: whether to use `torch.compile`, default is False.
If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ @@ -238,6 +244,8 @@ def __init__( decollate: bool = True, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -259,8 +267,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for PyTorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: @@ -288,6 +304,24 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor when using compile, and cast back to MetaTensor after the forward pass." + ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -298,6 +332,19 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED)
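For reference, the new `compile`/`compile_kwargs` options added to `SupervisedEvaluator` above (and to `SupervisedTrainer` in the next file) can be exercised as follows; a minimal sketch in which the toy network, data and hyper-parameters are illustrative assumptions, not part of this patch:

```python
# Minimal sketch of the new `compile` option; toy network/data are placeholders.
import torch
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedTrainer

net = torch.nn.Conv3d(1, 1, 3, padding=1)
data = [{"image": torch.rand(1, 16, 16, 16), "label": torch.rand(1, 16, 16, 16)} for _ in range(4)]

trainer = SupervisedTrainer(
    device=torch.device("cpu"),
    max_epochs=1,
    train_data_loader=DataLoader(Dataset(data), batch_size=2),
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), 1e-3),
    loss_function=torch.nn.MSELoss(),
    compile=True,  # wraps `net` in torch.compile on PyTorch >= 2.1, otherwise warns and skips
    compile_kwargs={"mode": "default"},  # forwarded verbatim to torch.compile
)
trainer.run()  # MetaTensor inputs are round-tripped through plain tensors per the workaround above
```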
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 61b7028e11..c1364fe015 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch @@ -18,13 +19,15 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo +from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import GanKeys, min_version, optional_import +from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys +from monai.utils.module import pytorch_after if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -34,7 +37,7 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] class Trainer(Workflow): @@ -125,7 +128,10 @@ class SupervisedTrainer(Trainer): `device`, `non_blocking`. amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. - + compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to + `torch.Tensor` before forward pass, then converted back afterward with copied meta information. + compile_kwargs: dict of the args for `torch.compile()` API, for more details: + https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile. """ def __init__( @@ -153,6 +159,8 @@ def __init__( optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, + compile: bool = False, + compile_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -174,8 +182,16 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - + if compile: + if pytorch_after(2, 1): + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + else: + warnings.warn( + "Network compilation (compile=True) not supported for PyTorch versions before 2.1, no compilation done" + ) self.network = network + self.compile = compile self.optimizer = optimizer self.loss_function = loss_function self.inferer = SimpleInferer() if inferer is None else inferer @@ -207,6 +223,25 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso kwargs: dict = {} else: inputs, targets, args, kwargs = batch + # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 + if self.compile: + inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None + if isinstance(inputs, MetaTensor): + warnings.warn( + "Will convert to PyTorch Tensor when using compile, and cast back to MetaTensor after the forward pass."
+ ) + inputs, inputs_meta, inputs_applied_operations = ( + inputs.as_tensor(), + inputs.meta, + inputs.applied_operations, + ) + if isinstance(targets, MetaTensor): + targets, targets_meta, targets_applied_operations = ( + targets.as_tensor(), + targets.meta, + targets.applied_operations, + ) + # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} @@ -231,6 +266,19 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) engine.optimizer.step() + # copy back meta info + if self.compile: + if inputs_meta is not None: + engine.state.output[Keys.IMAGE] = MetaTensor( + inputs, meta=inputs_meta, applied_operations=inputs_applied_operations + ) + engine.state.output[Keys.PRED] = MetaTensor( + engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations + ) + if targets_meta is not None: + engine.state.output[Keys.LABEL] = MetaTensor( + targets, meta=targets_meta, applied_operations=targets_applied_operations + ) engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output @@ -423,3 +471,282 @@ def _iteration( GanKeys.GLOSS: g_loss.item(), GanKeys.DLOSS: d_total_loss.item(), } + + +class AdversarialTrainer(Trainer): + """ + Standard supervised training workflow for adversarial loss enabled neural networks. + + Args: + device: an object representing the device on which to run. + max_epochs: the total epoch number for engine to run. + train_data_loader: Core ignite engines use `DataLoader` for training loop batchdata. + g_network: generator (G) network architecture. + g_optimizer: G optimizer function. + g_loss_function: G loss function for adversarial training. + recon_loss_function: G loss function for reconstructions. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer function. + d_loss_function: D loss function for adversarial training. + epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to + the host. For other cases, this argument has no effect. + prepare_batch: function to parse image and label for current iteration. + iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input + parameters. if not provided, use `self._iteration()` instead. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based + transforms composed by `Compose`. Defaults to None. + key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics + when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args + (current_metric, previous_best) and return a bool result: if `True`, will update `best_metric` and + `best_metric_epoch` with current metric and epoch, default to `greater than`.
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, recommend + `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device | str, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + g_loss_function: Callable, + recon_loss_function: Callable, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + d_loss_function: Callable, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable | None = None, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + ): + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.register_events(*AdversarialIterationEvents) + + self.state.g_network = g_network + self.state.g_optimizer = g_optimizer + self.state.g_loss_function = g_loss_function + self.state.recon_loss_function = recon_loss_function + + self.state.d_network = d_network + self.state.d_optimizer = d_optimizer + self.state.d_loss_function = d_loss_function + + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + + self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None + + 
self.optim_set_to_none = optim_set_to_none + self._complete_state_dict_user_keys() + + def _complete_state_dict_user_keys(self) -> None: + """ + This method appends to `_state_dict_user_keys` the AdversarialTrainer elements that are required for + checkpoint saving. + + Follows the example found at: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict + """ + self._state_dict_user_keys.extend( + ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"] + ) + + g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None) + if callable(g_loss_state_dict): + self._state_dict_user_keys.append("g_loss_function") + + d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None) + if callable(d_loss_state_dict): + self._state_dict_user_keys.append("d_loss_function") + + recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None) + if callable(recon_loss_state_dict): + self._state_dict_user_keys.append("recon_loss_function") + + def _iteration( + self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor | int | float | bool]: + """ + Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised + Learning this is equal to IMAGE. + - PRED: prediction result of model. + - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up). + - AdversarialKeys.REALS: real images from the batch. These are the same as IMAGE. + - AdversarialKeys.FAKES: fake images generated by the generator. These are the same as PRED. + - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images. + - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images. + - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function. + - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the + discriminator loss for the fake images, which is backpropagated through the generator only. + - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images, which is backpropagated through the + discriminator only. + + Args: + engine: `AdversarialTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + + Raises: + ValueError: must provide batch data for current iteration.
+ + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + + if len(batch) == 2: + inputs, targets = batch + args: tuple = () + kwargs: dict = {} + else: + inputs, targets, args, kwargs = batch + + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs} + + def _compute_generator_loss() -> None: + engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer( + inputs, engine.state.g_network, *args, **kwargs + ) + engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES] + engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs + ) + engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function( + engine.state.output[AdversarialKeys.FAKES], targets + ).mean() + engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED) + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function( + engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED) + + # Train Generator + engine.state.g_network.train() + engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.state.g_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_generator_loss() + + engine.state.output[Keys.LOSS] = ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ) + engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_scaler.step(engine.state.g_optimizer) + engine.state.g_scaler.update() + else: + _compute_generator_loss() + ( + engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] + + engine.state.output[AdversarialKeys.GENERATOR_LOSS] + ).backward() + engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED) + engine.state.g_optimizer.step() + engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED) + + def _compute_discriminator_loss() -> None: + engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.REALS].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer( + engine.state.output[AdversarialKeys.FAKES].contiguous().detach(), + engine.state.d_network, + *args, + **kwargs, + ) + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED) + + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function( + engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS] + ).mean() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED) + + # Train Discriminator + engine.state.d_network.train() + engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and 
engine.state.d_scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_discriminator_loss() + + engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward() + engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED) + engine.state.d_scaler.step(engine.state.d_optimizer) + engine.state.d_scaler.update() + else: + _compute_discriminator_loss() + engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward() + engine.state.d_optimizer.step() + + return engine.state.output
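To illustrate how the pieces above fit together, a hypothetical `AdversarialTrainer` setup might look like the following; the toy 2-D generator/discriminator pair, the data and the loss wiring are assumptions for the sketch (only `PatchAdversarialLoss`, already exported from `monai.losses`, and the engine API come from the codebase):

```python
# Hypothetical AdversarialTrainer wiring; networks/data/losses are toy placeholders.
import torch
from monai.data import DataLoader, Dataset
from monai.engines import AdversarialTrainer
from monai.losses import PatchAdversarialLoss

gen = torch.nn.Conv2d(1, 1, 3, padding=1)
disc = torch.nn.Conv2d(1, 1, 3, padding=1)
adv = PatchAdversarialLoss(criterion="least_squares")
data = [{"image": torch.rand(1, 32, 32), "label": torch.rand(1, 32, 32)} for _ in range(4)]

trainer = AdversarialTrainer(
    device=torch.device("cpu"),
    max_epochs=1,
    train_data_loader=DataLoader(Dataset(data), batch_size=2),
    g_network=gen,
    g_optimizer=torch.optim.Adam(gen.parameters(), 1e-4),
    # called with FAKE_LOGITS only: push fakes towards the "real" label
    g_loss_function=lambda fake_logits: adv(fake_logits, target_is_real=True, for_discriminator=False),
    recon_loss_function=torch.nn.L1Loss(),
    d_network=disc,
    d_optimizer=torch.optim.Adam(disc.parameters(), 1e-4),
    # called with (REAL_LOGITS, FAKE_LOGITS): sum of real/fake discriminator terms
    d_loss_function=lambda real_logits, fake_logits: adv(real_logits, target_is_real=True, for_discriminator=True)
    + adv(fake_logits, target_is_real=False, for_discriminator=True),
)
trainer.run()
```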
+ + """ + + def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images: torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: dict[str, torch.Tensor], + device: str | torch.device | None = None, + non_blocking: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]: + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". 
+ + """ + + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) + + def default_make_latent( num_latents: int, latent_size: int, diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index a2bd345dc6..df209c1c8b 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -401,7 +401,7 @@ def _default_iteration_log(self, engine: Engine) -> None: cur_optimizer = engine.optimizer for param_name in self.optimizer_param_names: params = { - f"{param_name} group_{i}": float(param_group[param_name]) + f"{param_name}_group_{i}": float(param_group[param_name]) for i, param_group in enumerate(cur_optimizer.param_groups) } self._log_metrics(params, step=engine.state.iteration) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..37e4468d4b 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -46,7 +46,10 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of pred using central finite difference. + Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference. + + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -75,6 +78,9 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: @@ -84,7 +90,8 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( - f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" ) # first order gradient @@ -116,3 +123,88 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference. 
+ For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + For more information, + see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. 
+ g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g**2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion
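As a quick check of the loss just added, `DiffusionLoss` can be run directly on a random dense displacement field; a minimal sketch (shapes are arbitrary, but each spatial size must exceed 2):

```python
# Minimal sketch exercising DiffusionLoss on a random 3-D displacement field.
import torch
from monai.losses import DiffusionLoss

ddf = torch.rand(2, 3, 8, 8, 8)  # B, C=3 vector components, three spatial dims > 2
loss = DiffusionLoss(normalize=True, reduction="mean")
print(loss(ddf))  # scalar tensor; a spatially constant field would score zero
```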
diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 39219e059a..dd132770ec 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -277,9 +277,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torc if order == 0: weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = ( - weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 - ) + weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: raise ValueError(f"Do not support b-spline {order}-order parzen windowing") diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index d9bbf17db3..d727eb0567 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -190,7 +190,7 @@ def compute_hausdorff_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=not directed, + symmetric=not directed, class_index=c, ) percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 635eb1bc24..b20b47a1a5 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -253,7 +253,7 @@ def compute_surface_dice( distance_metric=distance_metric, spacing=spacing_list[b], use_subvoxels=use_subvoxels, - symetric=True, + symmetric=True, class_index=c, ) boundary_correct: int | torch.Tensor | float diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7ce632c588..3cb336d6a0 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -177,7 +177,7 @@ def compute_average_surface_distance( y[b, c], distance_metric=distance_metric, spacing=spacing_list[b], - symetric=symmetric, + symmetric=symmetric, class_index=c, ) surface_distance = torch.cat(distances) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 62e6520b96..e7057256fb 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -38,10 +38,6 @@ binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") -cucim_distance_transform_edt, has_cucim_distance_transform_edt = optional_import( - "cucim.core.operations.morphology", name="distance_transform_edt" -) __all__ = [ "ignore_background", @@ -179,6 +175,8 @@ def get_mask_edges( always_return_as_numpy: whether to a numpy array regardless of the input type. If False, return the same type as inputs. """ + # import within the function to avoid using all the GPUs + cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") converter: Any @@ -295,7 +293,7 @@ def get_edge_surface_distance( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, use_subvoxels: bool = False, - symetric: bool = False, + symmetric: bool = False, class_index: int = -1, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], @@ -314,7 +312,7 @@ See :py:func:`monai.metrics.utils.get_surface_distance`. use_subvoxels: whether to use subvoxel resolution (using the spacing). This will return the areas of the edges. - symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. + symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. Returns: @@ -338,7 +336,7 @@ " this may result in nan/inf distance." ) distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] - if symetric: + if symmetric: distances = ( get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 10c4ce3d8e..6f96dfd291 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1024,7 +1024,7 @@ def __init__( self.layers4.append(layer) if self.use_v2: layerc = UnetrBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=embed_dim * 2**i_layer, out_channels=embed_dim * 2**i_layer, kernel_size=3, diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index ff27903cef..6bfff3c956 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -12,20 +12,17 @@ from __future__ import annotations import math -import os -import shutil -import tarfile -import tempfile from collections.abc import Sequence import torch from torch import nn +from monai.config.type_definitions import PathLike from monai.utils import optional_import transformers = optional_import("transformers") load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0] -cached_path = optional_import("transformers.file_utils", name="cached_path")[0] +cached_file = optional_import("transformers.utils", name="cached_file")[0] BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0] BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0] @@ -63,44 +60,16 @@ def from_pretrained( state_dict=None, cache_dir=None, from_tf=False, + path_or_repo_id="bert-base-uncased", + filename="pytorch_model.bin", *inputs, **kwargs, ): - archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz" - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - tempdir = tempfile.mkdtemp() - with tarfile.open(resolved_archive_file, "r:gz") as archive: - - def
is_within_directory(directory, target): - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) - - prefix = os.path.commonprefix([abs_directory, abs_target]) - - return prefix == abs_directory - - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - - tar.extractall(path, members, numeric_owner=numeric_owner) - - safe_extract(archive, tempdir) - serialization_dir = tempdir + weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir) model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, "pytorch_model.bin") state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) - if tempdir: - shutil.rmtree(tempdir) if from_tf: - weights_path = os.path.join(serialization_dir, "model.ckpt") return load_tf_weights_in_bert(model, weights_path) old_keys = [] new_keys = [] @@ -304,6 +273,8 @@ def __init__( chunk_size_feed_forward: int = 0, is_decoder: bool = False, add_cross_attention: bool = False, + path_or_repo_id: str | PathLike = "bert-base-uncased", + filename: str = "pytorch_model.bin", ) -> None: """ Args: @@ -315,6 +286,10 @@ num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. drop_out: fraction of the input units to drop. + path_or_repo_id: This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filename: The name of the file to locate in `path_or_repo_id`. The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. @@ -369,6 +344,8 @@ def __init__( num_vision_layers=num_vision_layers, num_mixed_layers=num_mixed_layers, bert_config=bert_config, + path_or_repo_id=path_or_repo_id, + filename=filename, ) self.patch_size = patch_size diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py index d4771e203a..f198bfbb2b 100644 --- a/monai/networks/nets/vqvae.py +++ b/monai/networks/nets/vqvae.py @@ -312,10 +312,16 @@ def __init__( channels: Sequence[int] = (96, 96, 192), num_res_layers: int = 3, num_res_channels: Sequence[int] | int = (96, 96, 192), - downsample_parameters: Sequence[Tuple[int, int, int, int]] - | Tuple[int, int, int, int] = ((2, 4, 1, 1), (2, 4, 1, 1), (2, 4, 1, 1)), - upsample_parameters: Sequence[Tuple[int, int, int, int, int]] - | Tuple[int, int, int, int, int] = ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0), (2, 4, 1, 1, 0)), + downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = ( + (2, 4, 1, 1), + (2, 4, 1, 1), + (2, 4, 1, 1), + ), + upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = ( + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + (2, 4, 1, 1, 0), + ), num_embeddings: int = 32, embedding_dim: int = 64, embedding_init: str = "normal", diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 56d214c51d..be9441dc4a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -221,9 +221,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension.
""" - LazyTransform.__init__(self, lazy) padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class BorderPadd(Padd): @@ -274,9 +273,8 @@ def __init__( note that `np.pad` treats channel dimension as the first dimension. """ - LazyTransform.__init__(self, lazy) padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class DivisiblePadd(Padd): @@ -324,9 +322,8 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ - LazyTransform.__init__(self, lazy) padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs) - Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) + Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy) class Cropd(MapTransform, InvertibleTransform, LazyTransform): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 41fabb35aa..f94f11eca9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -185,7 +185,17 @@ def track_transform_meta( # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] - affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + try: + affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + except RuntimeError as e: + if orig_affine.ndim > 2: + if data_t.is_batch: + msg = "Transform applied to batched tensor, should be applied to instances only" + else: + msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation." + raise RuntimeError(msg) from e + else: + raise out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64) if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd7e4ef090..7222a26fc3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -414,6 +414,9 @@ def __init__( self.fname_formatter = output_name_formatter self.output_ext = output_ext.lower() or output_format.lower() + self.output_ext = ( + f".{self.output_ext}" if self.output_ext and not self.output_ext.startswith(".") else self.output_ext + ) if isinstance(writer, str): writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in if not has_built_in: @@ -458,15 +461,23 @@ def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, writ self.write_kwargs.update(write_kwargs) return self - def __call__(self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None): + def __call__( + self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None + ): """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of metadata corresponding to the data. + filename: str or file-like object which to save img. 
+ If specified, will ignore `self.output_name_formatter` and `self.folder_layout`. """ meta_data = img.meta if isinstance(img, MetaTensor) else meta_data - kw = self.fname_formatter(meta_data, self) - filename = self.folder_layout.filename(**kw) + if filename is not None: + filename = f"{filename}{self.output_ext}" + else: + kw = self.fname_formatter(meta_data, self) + filename = self.folder_layout.filename(**kw) + if meta_data: meta_spatial_shape = ensure_tuple(meta_data.get("spatial_shape", ())) if len(meta_spatial_shape) >= len(img.shape): diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2322f2123f..5dfbcb0e91 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + def __call__( + self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None + ) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + applied_operations: An optional list of operations that have been applied to the data Returns: A MetaTensor with the same shape as `img` and identical metadata """ if isinstance(img, MetaTensor): meta_dict = img.meta + applied_operations = img.applied_operations + img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict: - img_ = MetaTensor(img_, meta=meta_dict) + if meta_dict is not None or applied_operations is not None: + img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index ec10bd8537..1cd9ff6323 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1765,9 +1765,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd -ConvertToMultiChannelBasedOnBratsClassesD = ( - ConvertToMultiChannelBasedOnBratsClassesDict -) = ConvertToMultiChannelBasedOnBratsClassesd +ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = ( + ConvertToMultiChannelBasedOnBratsClassesd +) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 20f09628ac..2418b43591 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -50,18 +50,15 @@ def get_dist_device(): @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ... 
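# the one-line `def ...: ...` stub style used for these overloads trips flake8's E704; this patch therefore adds E704 to the setup.cfg ignore list below (see https://github.com/Project-MONAI/MONAI/issues/7421)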
@overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]: ... @overload -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: - ... +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ... def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4f2501a7ee..81f582daef 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -103,13 +103,11 @@ def star_zip_with(op, *vals): @overload -def first(iterable: Iterable[T], default: T) -> T: - ... +def first(iterable: Iterable[T], default: T) -> T: ... @overload -def first(iterable: Iterable[T]) -> T | None: - ... +def first(iterable: Iterable[T]) -> T | None: ... def first(iterable: Iterable[T], default: T | None = None) -> T | None: diff --git a/requirements-dev.txt b/requirements-dev.txt index 6332d5b0a5..f8bc9d5a3e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Full requirements for developments -r requirements-min.txt pytorch-ignite==0.4.11 -gdown>=4.4.0 +gdown>=4.4.0, <=4.6.3 scipy>=1.7.1 itk>=5.2 nibabel @@ -27,13 +27,13 @@ ninja torchvision psutil cucim>=23.2.0; platform_system == "Linux" -openslide-python==1.1.2 +openslide-python imagecodecs; platform_system == "Linux" or platform_system == "Darwin" tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers>=4.36.0 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/setup.cfg b/setup.cfg index 0069214de3..4180ced917 100644 --- a/setup.cfg +++ b/setup.cfg @@ -174,6 +174,7 @@ max_line_length = 120 # B907 https://github.com/Project-MONAI/MONAI/issues/5868 # B908 https://github.com/Project-MONAI/MONAI/issues/6503 # B036 https://github.com/Project-MONAI/MONAI/issues/7396 +# E704 https://github.com/Project-MONAI/MONAI/issues/7421 ignore = E203 E501 @@ -188,6 +189,7 @@ ignore = B907 B908 B036 + E704 per_file_ignores = __init__.py: F401, __main__.py: F401 exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py diff --git a/tests/min_tests.py b/tests/min_tests.py index 8128bb7b84..3a143df84b 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -154,6 +154,7 @@ def run_testsuit(): "test_plot_2d_or_3d_image", "test_png_rw", "test_prepare_batch_default", + "test_prepare_batch_diffusion", "test_prepare_batch_extra_input", "test_prepare_batch_hovernet", "test_rand_grid_patch", diff --git a/tests/padders.py b/tests/padders.py index 02d7b40af6..ae1153bdfd 100644 --- a/tests/padders.py +++ b/tests/padders.py @@ -136,6 +136,9 @@ def pad_test_pending_ops(self, input_param, input_shape): # TODO: mode="bilinear" may report error overrides = {"mode": "nearest", "padding_mode": mode[1], "align_corners": False} result = apply_pending(pending_result, overrides=overrides)[0] + # lazy in constructor + pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param) + self.assertTrue(pad_fn_lazy.lazy) # compare assert_allclose(result, expected, rtol=1e-5) if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform): diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb 
--- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 ≈ 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied; using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # averaging over the two components, the diffusion loss is then + # ((1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2) / 2 = (1 + 9/25 + 25/9 + 1) / 2 ≈ 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + 
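+ # with the default "mean" reduction the loss is a 0-d tensor, so this compares one scalar against the hand-computed expected_val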
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 3, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py index 4c49aecd8b..68fa0b1192 100644 --- a/tests/test_hilbert_transform.py +++ b/tests/test_hilbert_transform.py @@ -180,15 +180,17 @@ def test_value(self, arguments, image, expected_data, atol): @SkipIfNoModule("torch.fft") class TestHilbertTransformGPU(unittest.TestCase): @parameterized.expand( - [] - if not torch.cuda.is_available() - else [ - TEST_CASE_1D_SINE_GPU, - TEST_CASE_2D_SINE_GPU, - TEST_CASE_3D_SINE_GPU, - TEST_CASE_1D_2CH_SINE_GPU, - TEST_CASE_2D_2CH_SINE_GPU, - ], + ( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ] + ), skip_on_empty=True, ) def test_value(self, arguments, image, expected_data, atol): diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 841a5d5cd5..985ea95e79 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd @@ -115,6 +116,21 @@ def test_call_3d(self, filter_name): out_tensor = filter(SAMPLE_IMAGE_3D) self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + def test_pass_applied_operations(self): + "Test that applied operations are passed through" + applied_operations = ["op1", "op2"] + image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertEqual(out_tensor.applied_operations, applied_operations) + + def test_pass_empty_metadata_dict(self): + "Test that an empty metadata dict still yields a MetaTensor" + image = MetaTensor(SAMPLE_IMAGE_2D, meta={}) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertTrue(isinstance(out_tensor, MetaTensor)) + class TestImageFilterDict(unittest.TestCase): 
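+ # the dictionary-wrapper tests below reuse SUPPORTED_FILTERS; the metadata/applied_operations pass-through is covered by the array-transform tests above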
@parameterized.expand(SUPPORTED_FILTERS) diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py new file mode 100644 index 0000000000..f323fc9917 --- /dev/null +++ b/tests/test_integration_workflows_adversarial.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import shutil +import tempfile +import unittest +from glob import glob + +import numpy as np +import torch + +import monai +from monai.data import create_test_image_2d +from monai.engines import AdversarialTrainer +from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler +from monai.networks.nets import AutoEncoder, Discriminator +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd +from monai.utils import AdversarialKeys as Keys +from monai.utils import CommonKeys, optional_import, set_determinism +from tests.utils import DistTestCase, TimedCall, skip_if_quick + +nib, has_nibabel = optional_import("nibabel") + + +def run_training_test(root_dir, device="cuda:0"): + learning_rate = 2e-4 + real_label = 1 + fake_label = 0 + + real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) + train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)] + + # prepare real data + train_transforms = Compose( + [ + LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]), + EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2), + ScaleIntensityd(keys=[CommonKeys.IMAGE]), + RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5), + ] + ) + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) + train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) + + # Create Discriminator + discriminator_net = Discriminator( + in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5 + ).to(device) + discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate) + discriminator_loss_criterion = torch.nn.BCELoss() + + def discriminator_loss(real_logits, fake_logits): + real_target = real_logits.new_full((real_logits.shape[0], 1), real_label) + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label) + real_loss = discriminator_loss_criterion(real_logits, real_target) + fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) + return torch.div(torch.add(real_loss, fake_loss), 2) + + # Create Generator + generator_network = AutoEncoder( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(8, 16, 32, 64), + strides=(2, 2, 2, 2), + num_res_units=1, + num_inter_units=1, + ) + generator_network = generator_network.to(device) + generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate) + generator_loss_criterion = torch.nn.MSELoss() + + def 
reconstruction_loss(recon_images, real_images): + return generator_loss_criterion(recon_images, real_images) + + def generator_loss(fake_logits): + fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label) + # no detach here: the adversarial loss must backpropagate through the discriminator into the generator + adversarial_loss = discriminator_loss_criterion(fake_logits, fake_target) + return adversarial_loss + + key_train_metric = None + + train_handlers = [ + StatsHandler( + name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + TensorBoardStatsHandler( + log_dir=root_dir, + tag_name="training_loss", + output_transform=lambda x: { + Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], + Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], + Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], + }, + ), + CheckpointSaver( + save_dir=root_dir, + save_dict={"g_net": generator_network, "d_net": discriminator_net}, + save_interval=2, + epoch_level=True, + ), + ] + + num_epochs = 5 + + trainer = AdversarialTrainer( + device=device, + max_epochs=num_epochs, + train_data_loader=train_loader, + g_network=generator_network, + g_optimizer=generator_optimiser, + g_loss_function=generator_loss, + recon_loss_function=reconstruction_loss, + d_network=discriminator_net, + d_optimizer=discriminator_opt, + d_loss_function=discriminator_loss, + non_blocking=True, + key_train_metric=key_train_metric, + train_handlers=train_handlers, + ) + trainer.run() + + return trainer.state + + +@skip_if_quick +@unittest.skipUnless(has_nibabel, "Requires nibabel library.") +class IntegrationWorkflowsAdversarialTrainer(DistTestCase): + def setUp(self): + set_determinism(seed=0) + + self.data_dir = tempfile.mkdtemp() + for i in range(40): + im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + monai.config.print_config() + + def tearDown(self): + set_determinism(seed=None) + shutil.rmtree(self.data_dir) + + @TimedCall(seconds=200, daemon=False) + def test_training(self): + torch.manual_seed(0) + + finish_state = run_training_test(self.data_dir, device=self.device) + + # Assert AdversarialTrainer training finished + self.assertEqual(finish_state.iteration, 100) + self.assertEqual(finish_state.epoch, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py new file mode 100644 index 0000000000..d969c06368 --- /dev/null +++ b/tests/test_prepare_batch_diffusion.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.engines import SupervisedEvaluator +from monai.engines.utils import DiffusionPrepareBatch +from monai.inferers import DiffusionInferer +from monai.networks.nets import DiffusionModelUNet +from monai.networks.schedulers import DDPMScheduler + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "channels": [8], + "norm_num_groups": 8, + "attention_levels": [True], + "num_res_blocks": 1, + "num_head_channels": 8, + }, + (2, 1, 8, 8, 8), + ], +] + + +class TestPrepareBatchDiffusion(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_output_sizes(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + @parameterized.expand(TEST_CASES) + def test_conditioning(self, input_args, image_size): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [{"image": torch.randn(image_size).to(device), "context": torch.randn((2, 4, 3)).to(device)}] + scheduler = DDPMScheduler(num_train_timesteps=20) + inferer = DiffusionInferer(scheduler=scheduler) + network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=network, + inferer=inferer, + non_blocking=True, + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name="context"), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + # check shapes are the same + self.assertEqual(output["pred"].shape, image_size) + self.assertEqual(output["label"].shape, output["image"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_save_image.py b/tests/test_save_image.py index ba94ab5087..d88db201ce 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -37,6 +37,8 @@ False, ] +TEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), ".dcm", False] + @unittest.skipUnless(has_itk, "itk not installed") class TestSaveImage(unittest.TestCase): @@ -58,6 +60,20 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample): filepath = "testfile0" if meta_data is not None else "0" self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + "_trans" + output_ext))) + @parameterized.expand([TEST_CASE_5]) + def test_saved_content_with_filename(self, test_data, output_ext, resample): + with tempfile.TemporaryDirectory() as tempdir: + trans = SaveImage( + output_dir=tempdir, + 
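+ # ".dcm" from TEST_CASE_5 already carries a leading dot; a bare "dcm" would now be normalized to ".dcm" by SaveImage as well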
output_ext=output_ext, + resample=resample, + separate_folder=False, # test saving into the same folder + ) + filename = str(os.path.join(tempdir, "test")) + trans(test_data, filename=filename) + + self.assertTrue(os.path.exists(filename + output_ext)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py index 53703e107a..993e8a4ac2 100644 --- a/tests/test_set_visible_devices.py +++ b/tests/test_set_visible_devices.py @@ -35,6 +35,13 @@ def test_visible_devices(self): ) self.assertEqual(num_gpus_before, num_gpus_after) + # test import monai won't affect setting CUDA_VISIBLE_DEVICES + num_gpus_after_monai = self.run_process_and_get_exit_code( + 'python -c "import os; import torch; import monai; ' + + "os.environ['CUDA_VISIBLE_DEVICES'] = '0'; exit(torch.cuda.device_count())\"" + ) + self.assertEqual(num_gpus_after_monai, 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 1ff1518297..8b664641d7 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -74,9 +74,11 @@ torch.ones((1, 2, 1, 2)), # data torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), {}, - torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) - if USE_COMPILED - else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ( + torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) + if USE_COMPILED + else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) + ), *device, ] )
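The `filename` argument added to `SaveImage.__call__` above can also be exercised directly. A minimal sketch mirroring TEST_CASE_5, assuming ITK is installed for DICOM output (the /tmp path is illustrative only):

    import torch
    from monai.transforms import SaveImage

    img = torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8)  # channel-first data without metadata
    # a bare "dcm" would also be accepted: SaveImage now normalizes it to ".dcm"
    saver = SaveImage(output_ext=".dcm", resample=False, separate_folder=False)
    saver(img, filename="/tmp/test")  # writes /tmp/test.dcm, bypassing the folder layout and name formatter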
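Similarly, the `symetric` -> `symmetric` rename in `get_edge_surface_distance` is a breaking keyword change for callers that spelled out the old misspelled argument. A hedged sketch of the updated call; the two binary masks are placeholders:

    import torch
    from monai.metrics.utils import get_edge_surface_distance

    pred_mask = torch.zeros(5, 5, 5, dtype=torch.bool)  # hypothetical binary segmentations
    gt_mask = torch.zeros(5, 5, 5, dtype=torch.bool)
    pred_mask[1:3, 1:3, 1:3] = True
    gt_mask[2:4, 2:4, 2:4] = True
    edges, distances, areas = get_edge_surface_distance(pred_mask, gt_mask, symmetric=True)  # keyword was `symetric`
    # with symmetric=True, `distances` holds both directions: pred -> gt and gt -> pred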