diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py
index ad0ec77bcf..e35fa44ae4 100644
--- a/monai/inferers/__init__.py
+++ b/monai/inferers/__init__.py
@@ -11,7 +11,15 @@

 from __future__ import annotations

-from .inferer import Inferer, PatchInferer, SaliencyInferer, SimpleInferer, SliceInferer, SlidingWindowInferer
+from .inferer import (
+    Inferer,
+    PatchInferer,
+    SaliencyInferer,
+    SimpleInferer,
+    SliceInferer,
+    SlidingWindowInferer,
+    SlidingWindowInfererAdapt,
+)
 from .merger import AvgMerger, Merger
 from .splitter import SlidingWindowSplitter, Splitter
 from .utils import sliding_window_inference
diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 9cf1ecc73c..1289c9db8f 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -20,6 +20,7 @@
 import torch
 import torch.nn as nn

+from monai.apps.utils import get_logger
 from monai.data.meta_tensor import MetaTensor
 from monai.inferers.merger import AvgMerger, Merger
 from monai.inferers.splitter import Splitter
@@ -27,7 +28,17 @@
 from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
 from monai.visualize import CAM, GradCAM, GradCAMpp

-__all__ = ["Inferer", "PatchInferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer", "SliceInferer"]
+logger = get_logger(__name__)
+
+__all__ = [
+    "Inferer",
+    "PatchInferer",
+    "SimpleInferer",
+    "SlidingWindowInferer",
+    "SaliencyInferer",
+    "SliceInferer",
+    "SlidingWindowInfererAdapt",
+]


 class Inferer(ABC):
@@ -448,7 +459,9 @@ def __call__(

         """

-        device = self.device
+        device = kwargs.pop("device", self.device)
+        buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)
+
         if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
             device = "cpu"  # stitch in cpu memory if image is too large
@@ -467,13 +480,96 @@ def __call__(
             self.progress,
             self.roi_weight_map,
             None,
-            self.buffer_steps,
+            buffer_steps,
             self.buffer_dim,
             *args,
             **kwargs,
         )


+class SlidingWindowInfererAdapt(SlidingWindowInferer):
+    """
+    SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically switch to buffered and then to CPU
+    stitching when OOM occurs on the GPU. It also records the size of such large images, so that CPU stitching is
+    attempted directly for the next image of a similar size. If the stitching ``device`` parameter is provided,
+    no automatic adaptation is attempted; keep the default ``device=None`` for adaptive behavior.
+    Note: the output might be on CPU (even if the input was on GPU) if the GPU memory was not sufficient.
+
+    """
+
+    def __call__(
+        self,
+        inputs: torch.Tensor,
+        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
+        """
+
+        Args:
+            inputs: model input data for inference.
+            network: target model to execute inference.
+                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
+            args: optional args to be passed to ``network``.
+            kwargs: optional keyword args to be passed to ``network``.
+ + """ + + # if device is provided, use without any adaptations + if self.device is not None: + return super().__call__(inputs, network, *args, **kwargs) + + skip_buffer = self.buffer_steps is not None and self.buffer_steps <= 0 + cpu_cond = self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh + gpu_stitching = inputs.is_cuda and not cpu_cond + buffered_stitching = inputs.is_cuda and cpu_cond and not skip_buffer + buffer_steps = max(1, self.buffer_steps) if self.buffer_steps is not None else 1 + + for _ in range(10): # at most 10 trials + try: + return super().__call__( + inputs, + network, + device=inputs.device if gpu_stitching else torch.device("cpu"), + buffer_steps=buffer_steps if buffered_stitching else None, + *args, + **kwargs, + ) + except RuntimeError as e: + if not gpu_stitching and not buffered_stitching or "OutOfMemoryError" not in str(type(e).__name__): + raise e + + logger.info(e) + + if gpu_stitching: # if failed on gpu + gpu_stitching = False + self.cpu_thresh = inputs.shape[2:].numel() - 1 # update thresh + + if skip_buffer: + buffered_stitching = False + logger.warning(f"GPU stitching failed, attempting on CPU, image dim {inputs.shape}..") + + else: + buffered_stitching = True + self.buffer_steps = buffer_steps + logger.warning( + f"GPU stitching failed, attempting with buffer {buffer_steps}, image dim {inputs.shape}.." + ) + elif buffer_steps > 1: + buffer_steps = max(1, buffer_steps // 2) + self.buffer_steps = buffer_steps + logger.warning( + f"GPU buffered stitching failed, image dim {inputs.shape} reducing buffer to {buffer_steps}" + ) + else: + buffered_stitching = False + self.buffer_steps = 0 # disable future buffer attempts + logger.warning(f"GPU buffered stitching failed, attempting on CPU, image dim {inputs.shape}") + raise RuntimeError( # not possible to finish after the trials + f"SlidingWindowInfererAdapt {skip_buffer} {cpu_cond} {gpu_stitching} {buffered_stitching} {buffer_steps}" + ) + + class SaliencyInferer(Inferer): """ SaliencyInferer is inference with activation maps. 
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index 94382c9381..fa1bb6d48e 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -216,6 +216,7 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None =
         and allclose(convert_to_numpy(in_shape, wrap_sequence=True), out_spatial_size)
     ):
         img.affine = call_kwargs["dst_affine"]
+        img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample
         return img
     img = monai.transforms.crop_or_pad_nd(img, matrix_np, out_spatial_size, mode=call_kwargs["padding_mode"])
     img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample
diff --git a/tests/test_resample.py b/tests/test_resample.py
index 2f9ef2cecb..c90dc5f13d 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -28,7 +28,10 @@ def rotate_90_2d():
     return t


-RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]])]
+RESAMPLE_FUNCTION_CASES = [
+    (get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]]),
+    (get_arange_img((3, 3)), torch.eye(3), get_arange_img((3, 3))[0]),
+]


 class TestResampleFunction(unittest.TestCase):
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index 117ad341c5..f9d49361a6 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -19,7 +19,7 @@

 from parameterized import parameterized

 from monai.data.utils import list_data_collate
-from monai.inferers import SlidingWindowInferer, sliding_window_inference
+from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference
 from monai.utils import optional_import
 from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick
@@ -305,6 +305,11 @@ def compute(data, test1, test2):
             )(inputs, compute, t1, test2=t2)
         np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

+        result = SlidingWindowInfererAdapt(
+            roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm
+        )(inputs, compute, t1, test2=t2)
+        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
+
     def test_multioutput(self):
         device = "cuda" if torch.cuda.is_available() else "cpu:0"
         inputs = torch.ones((1, 6, 20, 20)).to(device=device)
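Note: the second RESAMPLE_FUNCTION_CASES entry pins down the early-return branch patched in monai/transforms/lazy/utils.py: an identity matrix returns the input unresampled, and the fix makes that branch cast to float32 like the general path. A quick check of that behavior (a sketch; the float64 arange input stands in for tests.utils.get_arange_img):

    import torch
    from monai.transforms.lazy.utils import resample
    from monai.utils import convert_to_tensor

    # channel-first 3x3 image with a non-float32 dtype
    img = convert_to_tensor(torch.arange(9, dtype=torch.float64).reshape(1, 3, 3))

    # an identity affine takes the early-return branch in resample()
    out = resample(img, torch.eye(3))
    print(out.dtype)  # expected: torch.float32, matching the non-identity branch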