diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 456fa10c41..8ae028bbd2 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -34,6 +34,10 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + pip uninstall -y monai + pip uninstall -y monai + pip uninstall -y monai-weekly + pip uninstall -y monai-weekly python -m pip install --upgrade torch torchvision torchaudio python -m pip install -r requirements-dev.txt rm -rf /github/home/.cache/torch/hub/mmars/ @@ -51,6 +55,7 @@ jobs: python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' # test auto3dseg + echo "test tag algo" BUILD_MONAI=0 ./runtests.sh --build python -m tests.test_auto3dseg_ensemble python -m tests.test_auto3dseg_hpo @@ -58,11 +63,16 @@ jobs: python -m tests.test_integration_gpu_customization # test latest template + echo "test latest algo" cd ../ + rm -rf research-contributions + rm -rf algorithm_templates git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git ls research-contributions/ - cp -r research-contributions/auto3dseg/algorithm_templates ../MONAI/ - cd research-contributions && git log -1 && cd .. + cp -r research-contributions/auto3dseg/algorithm_templates MONAI/ + cd research-contributions && git log -1 && cd ../MONAI + pwd + ls -ll export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 export MONAI_TESTING_ALGO_TEMPLATE=algorithm_templates @@ -72,6 +82,9 @@ jobs: python -m tests.test_integration_gpu_customization # the other tests + echo "the other tests" + pwd + ls -ll BUILD_MONAI=1 ./runtests.sh --build --net BUILD_MONAI=1 ./runtests.sh --build --unittests if pgrep python; then pkill python; fi diff --git a/.gitignore b/.gitignore index 4e235c7774..2fd28bbfc7 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,4 @@ tests/testing_data/CT_2D_head_moving.mha # profiling results *.prof +runs diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py index da4ac4fd7a..7a60c23aa2 100644 --- a/monai/apps/pathology/inferers/inferer.py +++ b/monai/apps/pathology/inferers/inferer.py @@ -176,6 +176,8 @@ def __call__( self.progress, self.roi_weight_map, self.process_output, + self.buffer_steps, + self.buffer_dim, *args, **kwargs, ) diff --git a/monai/data/utils.py b/monai/data/utils.py index d5dddb5d55..135a35f205 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -163,7 +163,7 @@ def iter_patch_slices( def dense_patch_slices( - image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int] + image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int], return_slice: bool = True ) -> list[tuple[slice, ...]]: """ Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image. 
@@ -172,6 +172,7 @@ def dense_patch_slices( image_size: dimensions of image to iterate over patch_size: size of patches to generate slices scan_interval: dense patch sampling interval + return_slice: whether to return a list of slices (or tuples of indices), defaults to True Returns: a list of slice objects defining each patch @@ -199,7 +200,9 @@ def dense_patch_slices( dim_starts.append(start_idx) starts.append(dim_starts) out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T - return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + if return_slice: + return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + return [tuple((s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] # type: ignore def iter_patch_position( @@ -1056,6 +1059,7 @@ def compute_importance_map( mode: BlendMode | str = BlendMode.CONSTANT, sigma_scale: Sequence[float] | float = 0.125, device: torch.device | int | str = "cpu", + dtype: torch.dtype | str | None = torch.float32, ) -> torch.Tensor: """Get importance map for different weight modes. @@ -1070,6 +1074,7 @@ def compute_importance_map( sigma_scale: Sigma_scale to calculate sigma for each dimension (sigma = sigma_scale * dim_size). Used for gaussian mode only. device: Device to put importance map on. + dtype: Data type of the output importance map. Raises: ValueError: When ``mode`` is not one of ["constant", "gaussian"]. @@ -1096,6 +1101,9 @@ def compute_importance_map( raise ValueError( f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}]." ) + # handle non-positive weights + min_non_zero = max(torch.min(importance_map).item(), 1e-3) + importance_map = torch.clamp_(importance_map.to(torch.float), min=min_non_zero).to(dtype) return importance_map diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 03b5e7a75f..952872b5ba 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -366,6 +366,9 @@ class SlidingWindowInferer(Inferer): cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory) when input image volume is larger than this threshold (in pixels/voxels). Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu. + buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + default is None, no buffer. + buffer_dim: the dimension along which the buffer are created, default is 0. Note: ``sw_batch_size`` denotes the max number of windows per network inference iteration, @@ -387,6 +390,8 @@ def __init__( progress: bool = False, cache_roi_weight_map: bool = False, cpu_thresh: int | None = None, + buffer_steps: int | None = None, + buffer_dim: int = 0, ) -> None: super().__init__() self.roi_size = roi_size @@ -400,6 +405,8 @@ def __init__( self.device = device self.progress = progress self.cpu_thresh = cpu_thresh + self.buffer_steps = buffer_steps + self.buffer_dim = buffer_dim # compute_importance_map takes long time when computing on cpu. We thus # compute it once if it's static and then save it for future usage @@ -415,7 +422,8 @@ def __init__( warnings.warn("cache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)") except BaseException as e: raise RuntimeError( - "Seems to be OOM. Please try smaller roi_size, or use mode='constant' instead of mode='gaussian'. " + f"roi size {self.roi_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" + "Seems to be OOM. 
Please try smaller patch size or mode='constant' instead of mode='gaussian'." ) from e def __call__( @@ -455,6 +463,8 @@ def __call__( self.progress, self.roi_weight_map, None, + self.buffer_steps, + self.buffer_dim, *args, **kwargs, ) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c4405911d0..59fb479904 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -11,28 +11,31 @@ from __future__ import annotations -import warnings +import itertools from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, Iterable +import numpy as np import torch import torch.nn.functional as F from monai.data.meta_tensor import MetaTensor from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size -from monai.transforms import Resize from monai.utils import ( BlendMode, PytorchPadMode, convert_data_type, convert_to_dst_type, ensure_tuple, + ensure_tuple_rep, fall_back_tuple, look_up_option, optional_import, + pytorch_after, ) tqdm, _ = optional_import("tqdm", name="tqdm") +_nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest" __all__ = ["sliding_window_inference"] @@ -42,7 +45,7 @@ def sliding_window_inference( roi_size: Sequence[int] | int, sw_batch_size: int, predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], - overlap: float = 0.25, + overlap: Sequence[float] | float = 0.25, mode: BlendMode | str = BlendMode.CONSTANT, sigma_scale: Sequence[float] | float = 0.125, padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, @@ -52,6 +55,8 @@ def sliding_window_inference( progress: bool = False, roi_weight_map: torch.Tensor | None = None, process_fn: Callable | None = None, + buffer_steps: int | None = None, + buffer_dim: int = -1, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -87,7 +92,7 @@ def sliding_window_inference( to ensure the scaled output ROI sizes are still integers. If the `predictor`'s input and output spatial sizes are different, we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. - overlap: Amount of overlap between scans. + overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``. mode: {``"constant"``, ``"gaussian"``} How to blend output of overlapping windows. Defaults to ``"constant"``. @@ -113,6 +118,12 @@ def sliding_window_inference( roi_weight_map: pre-computed (non-negative) weight map for each ROI. If not given, and ``mode`` is not `constant`, this map will be computed on the fly. process_fn: process inference output and adjust the importance map per window + buffer_steps: the number of sliding window iterations along the ``buffer_dim`` + to be buffered on ``sw_device`` before writing to ``device``. + default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size, + (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. + buffer_dim: the spatial dimension along which the buffers are created. + 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -120,21 +131,31 @@ def sliding_window_inference( - input must be channel-first and have a batch dim, supports N-D sliding window. 
""" - compute_dtype = inputs.dtype + buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 - if overlap < 0 or overlap >= 1: - raise ValueError("overlap must be >= 0 and < 1.") + if buffered: + if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: + raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") + if buffer_dim < 0: + buffer_dim += num_spatial_dims + overlap = ensure_tuple_rep(overlap, num_spatial_dims) + for o in overlap: + if o < 0 or o >= 1: + raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") + compute_dtype = inputs.dtype # determine image spatial size and batch size # Note: all input images must have the same image size and batch size batch_size, _, *image_size_ = inputs.shape + device = device or inputs.device + sw_device = sw_device or inputs.device - if device is None: - device = inputs.device - if sw_device is None: - sw_device = inputs.device - + temp_meta = None + if isinstance(inputs, MetaTensor): + temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False) + inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0] roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) pad_size = [] @@ -142,16 +163,29 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - - if max(pad_size) > 0: + if any(pad_size): inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) + # Store all slices scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered) - # Store all slices in list - slices = dense_patch_slices(image_size, roi_size, scan_interval) num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows + windows_range: Iterable + if not buffered: + non_blocking = False + windows_range = range(0, total_slices, sw_batch_size) + else: + slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( + slices, batch_size, sw_batch_size, buffer_dim, buffer_steps + ) + non_blocking, _ss = torch.cuda.is_available(), -1 + for x in b_slices[:n_per_batch]: + if x[1] < _ss: # detect overlapping slices + non_blocking = False + break + _ss = x[2] # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -159,152 +193,166 @@ def sliding_window_inference( importance_map_ = roi_weight_map else: try: + valid_p_size = ensure_tuple(valid_patch_size) importance_map_ = compute_importance_map( - valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device + valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) - except BaseException as e: + if len(importance_map_.shape) == num_spatial_dims and not process_fn: + importance_map_ = importance_map_[None, None] # adds batch, channel dimensions + except Exception as e: raise RuntimeError( + f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." 
) from e - importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0] - - # handle non-positive weights - min_non_zero = max(torch.min(importance_map_).item(), 1e-3) - importance_map_ = torch.clamp_(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype) - - # Perform predictions - dict_key, output_image_list, count_map_list = None, [], [] - _initialized_ss = -1 - is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple) + importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] + # stores output and count map + output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore # for each patch - for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): - slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + for slice_g in tqdm(windows_range) if progress else windows_range: + slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices)) unravel_slice = [ - [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) + [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] - window_data = torch.cat( - [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice] - ).to(sw_device) - seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation - - # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. - seg_prob_tuple: tuple[torch.Tensor, ...] - if isinstance(seg_prob_out, torch.Tensor): - seg_prob_tuple = (seg_prob_out,) - elif isinstance(seg_prob_out, Mapping): - if dict_key is None: - dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys - seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key) - is_tensor_output = False + if sw_batch_size > 1: + win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) else: - seg_prob_tuple = ensure_tuple(seg_prob_out) - is_tensor_output = False + win_data = inputs[unravel_slice[0]].to(sw_device) + seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch + # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. 
+ dict_keys, seg_tuple = _flatten_struct(seg_prob_out) if process_fn: - seg_prob_tuple, importance_map = process_fn(seg_prob_tuple, window_data, importance_map_) + seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_) + else: + w_t = importance_map_ + if len(w_t.shape) == num_spatial_dims: + w_t = w_t[None, None] + w_t = w_t.to(dtype=compute_dtype, device=sw_device) + if buffered: + c_start, c_end = b_slices[b_s][1:] + if not sw_device_buffer: + k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored + sp_size = list(image_size) + sp_size[buffer_dim] = c_end - c_start + sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] + for p, s in zip(seg_tuple[0], unravel_slice): + offset = s[buffer_dim + 2].start - c_start + s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) + s[0] = slice(0, 1) + sw_device_buffer[0][s] += p * w_t + b_i += len(unravel_slice) + if b_i < b_slices[b_s][0]: + continue else: - importance_map = importance_map_ - - # for each output in multi-output list - for ss, seg_prob in enumerate(seg_prob_tuple): - seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN - - # compute zoom scale: out_roi_size/in_roi_size - zoom_scale = [] - for axis, (img_s_i, out_w_i, in_w_i) in enumerate( - zip(image_size, seg_prob.shape[2:], window_data.shape[2:]) - ): - _scale = out_w_i / float(in_w_i) - if not (img_s_i * _scale).is_integer(): - warnings.warn( - f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial " - f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs." - ) - zoom_scale.append(_scale) - - if _initialized_ss < ss: # init. the ss-th buffer at the first iteration - # construct multi-resolution outputs - output_classes = seg_prob.shape[1] - output_shape = [batch_size, output_classes] + [ - int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) - ] + sw_device_buffer = list(seg_tuple) + + for ss in range(len(sw_device_buffer)): + b_shape = sw_device_buffer[ss].shape + seg_chns, seg_shape = b_shape[1], b_shape[2:] + z_scale = None + if not buffered and seg_shape != roi_size: + z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] + w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode) + if len(output_image_list) <= ss: + output_shape = [batch_size, seg_chns] + output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts - output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) + new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore + output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) - _initialized_ss += 1 - - # resizing the importance_map - resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False) - - # store the result in the proper location of the full output. Apply weights from importance map. 
- for idx, original_idx in zip(slice_range, unravel_slice): - # zoom roi - original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image - for axis in range(2, len(original_idx_zoom)): - zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] - zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] - if not zoomed_start.is_integer() or (not zoomed_end.is_integer()): - warnings.warn( - f"For axis-{axis-2} of output[{ss}], the output roi range is not int. " - f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). " - f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. " - f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n" - f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. " - "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." - ) - original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) - importance_map_zoom = ( - resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) - if seg_prob.shape[2:] != importance_map.shape - else importance_map.to(compute_dtype) - ) - # store results and weights - output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g] - count_map_list[ss][original_idx_zoom] += ( - importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape) - ) + w_t_ = w_t.to(device) + for __s in slices: + if z_scale is not None: + __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)) + count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_ + if buffered: + o_slice = [slice(None)] * len(inputs.shape) + o_slice[buffer_dim + 2] = slice(c_start, c_end) + img_b = b_s // n_per_batch # image batch index + o_slice[0] = slice(img_b, img_b + 1) + if non_blocking: + output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + else: + output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + else: + sw_device_buffer[ss] *= w_t + sw_device_buffer[ss] = sw_device_buffer[ss].to(device) + _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss]) + sw_device_buffer = [] + if buffered: + b_s += 1 + + if non_blocking: + torch.cuda.current_stream().synchronize() # account for any overlapping sections for ss in range(len(output_image_list)): - output_image_list[ss] = output_image_list[ss] - _map = count_map_list.pop(0) - for _i in range(output_image_list[ss].shape[1]): - output_image_list[ss][:, _i : _i + 1, ...] 
/= _map - output_image_list[ss] = output_image_list[ss].to(compute_dtype) + output_image_list[ss] /= count_map_list.pop(0) # remove padding if image_size smaller than roi_size - for ss, output_i in enumerate(output_image_list): - zoom_scale = [ - seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) - ] - - final_slicing: list[slice] = [] - for sp in range(num_spatial_dims): - slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) - slice_dim = slice( - int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), - int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), - ) - final_slicing.insert(0, slice_dim) - while len(final_slicing) < len(output_i.shape): - final_slicing.insert(0, slice(None)) - output_image_list[ss] = output_i[final_slicing] - - if dict_key is not None: # if output of predictor is a dict - final_output = dict(zip(dict_key, output_image_list)) - else: - final_output = tuple(output_image_list) # type: ignore - final_output = final_output[0] if is_tensor_output else final_output - - if isinstance(inputs, MetaTensor): - final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore - return final_output + if any(pad_size): + for ss, output_i in enumerate(output_image_list): + zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] + final_slicing: list[slice] = [] + for sp in range(num_spatial_dims): + si = num_spatial_dims - sp - 1 + slice_dim = slice( + int(round(pad_size[sp * 2] * zoom_scale[si])), + int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])), + ) + final_slicing.insert(0, slice_dim) + output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)] + + final_output = _pack_struct(output_image_list, dict_keys) + final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore + if temp_meta is not None: + final_output = MetaTensor(final_output).copy_meta_from(temp_meta) + return final_output # type: ignore + + +def _create_buffered_slices(slices, batch_size, sw_batch_size, buffer_dim, buffer_steps): + """rearrange slices for buffering""" + slices_np = np.asarray(slices) + slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")] + slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] + slices_np = slices_np[:, buffer_dim] + + _, _, _b_lens = np.unique(slices_np[:, 0], return_counts=True, return_index=True) + b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries + x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))] # type: ignore + if x[-1] < b_ends[-1]: + x.append(b_ends[-1]) + n_per_batch = len(x) - 1 + windows_range = [ + range(b * x[-1] + x[i], b * x[-1] + x[i + 1], sw_batch_size) + for b in range(batch_size) + for i in range(n_per_batch) + ] + b_slices = [] + for _s, _r in enumerate(windows_range): + s_s = slices_np[windows_range[_s - 1].stop % len(slices) if _s > 0 else 0, 0] + s_e = slices_np[(_r.stop - 1) % len(slices), 1] + b_slices.append((_r.stop, s_s, s_e)) # buffer index, slice start, slice end + windows_range = itertools.chain(*windows_range) # type: ignore + return slices, n_per_batch, b_slices, windows_range + + +def _compute_coords(sw, coords, z_scale, out, patch): + """sliding window batch spatial scaling indexing for multi-resolution outputs.""" + for original_idx, p in zip(coords, patch): + idx_zm = list(original_idx) # 4D for 2D image, 5D for 3D image 
+ if z_scale: + for axis in range(2, len(idx_zm)): + idx_zm[axis] = slice( + int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) + ) + out[idx_zm] += p def _get_scan_interval( - image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float + image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: Sequence[float] ) -> tuple[int, ...]: """ Compute scan interval according to the image size, roi size and overlap. @@ -313,15 +361,36 @@ def _get_scan_interval( """ if len(image_size) != num_spatial_dims: - raise ValueError("image coord different from spatial dims.") + raise ValueError(f"len(image_size) {len(image_size)} different from spatial dims {num_spatial_dims}.") if len(roi_size) != num_spatial_dims: - raise ValueError("roi coord different from spatial dims.") + raise ValueError(f"len(roi_size) {len(roi_size)} different from spatial dims {num_spatial_dims}.") scan_interval = [] - for i in range(num_spatial_dims): + for i, o in zip(range(num_spatial_dims), overlap): if roi_size[i] == image_size[i]: scan_interval.append(int(roi_size[i])) else: - interval = int(roi_size[i] * (1 - overlap)) + interval = int(roi_size[i] * (1 - o)) scan_interval.append(interval if interval > 0 else 1) return tuple(scan_interval) + + +def _flatten_struct(seg_out): + dict_keys = None + seg_probs: tuple[torch.Tensor, ...] + if isinstance(seg_out, torch.Tensor): + seg_probs = (seg_out,) + elif isinstance(seg_out, Mapping): + dict_keys = sorted(seg_out.keys()) # track predictor's output keys + seg_probs = tuple(seg_out[k] for k in dict_keys) + else: + seg_probs = ensure_tuple(seg_out) # type: ignore + return dict_keys, seg_probs + + +def _pack_struct(seg_out, dict_keys=None): + if dict_keys is not None: + return dict(zip(dict_keys, seg_out)) + if isinstance(seg_out, (list, tuple)) and len(seg_out) == 1: + return seg_out[0] + return ensure_tuple(seg_out) diff --git a/monai/utils/module.py b/monai/utils/module.py index b72e3ff139..fcd5e04145 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -12,6 +12,7 @@ from __future__ import annotations import enum +import functools import os import pdb import re @@ -508,6 +509,7 @@ def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): return default +@functools.lru_cache(None) def get_torch_version_tuple(): """ Returns: @@ -562,6 +564,7 @@ def _try_cast(val: str) -> int | str: return True +@functools.lru_cache(None) def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool: """ Compute whether the current pytorch version is after or equal to the specified version. 
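For reference, a minimal usage sketch of the features introduced above: the per-dimension ``overlap`` and the new ``buffer_steps``/``buffer_dim`` arguments of ``sliding_window_inference``. The ``toy_predictor``, tensor shapes, and parameter values are illustrative assumptions, not taken from the diff.

import torch
from monai.inferers import sliding_window_inference

def toy_predictor(patch: torch.Tensor) -> torch.Tensor:
    # stand-in network: any callable mapping a patch batch to a prediction batch works
    return 2.0 * patch

image = torch.rand(1, 1, 96, 96, 64)  # channel-first input with a batch dim, as the docstring requires
output = sliding_window_inference(
    image,
    roi_size=(64, 64, 32),
    sw_batch_size=4,
    predictor=toy_predictor,
    overlap=(0.25, 0.25, 0.0),  # overlap may now differ per spatial dimension
    mode="gaussian",
    buffer_steps=2,   # accumulate two window iterations on sw_device before writing to device
    buffer_dim=-1,    # buffer along the last spatial dimension (the default)
)
print(output.shape)  # torch.Size([1, 1, 96, 96, 64])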
diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 0dc2216c22..8f7f8346cc 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -232,6 +232,8 @@ def compute(data, test1, test2): has_tqdm, None, None, + None, + 0, t1, test2=t2, ) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 5f07084927..117ad341c5 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -21,7 +21,7 @@ from monai.data.utils import list_data_collate from monai.inferers import SlidingWindowInferer, sliding_window_inference from monai.utils import optional_import -from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda +from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick _, has_tqdm = optional_import("tqdm") @@ -45,8 +45,56 @@ [(5, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi ] +_devices = [["cpu", "cuda:0"]] if torch.cuda.is_available() else [["cpu"]] +_windows = [ + [(2, 3, 10, 11), (7, 10), 0.8, 5], + [(2, 3, 10, 11), (15, 12), 0, 2], + [(2, 3, 10, 11), (10, 11), 0, 3], + [(2, 3, 511, 237), (96, 80), 0.4, 5], + [(2, 3, 512, 245), (96, 80), 0, 5], + [(2, 3, 512, 245), (512, 80), 0.125, 5], + [(2, 3, 10, 11, 12), (7, 8, 10), 0.2, 2], +] +if not test_is_quick(): + _windows += [ + [(2, 1, 125, 512, 200), (96, 97, 98), (0.4, 0.32, 0), 20], + [(2, 1, 10, 512, 200), (96, 97, 98), (0.4, 0.12, 0), 21], + [(2, 3, 100, 100, 200), (50, 50, 100), 0, 8], + ] + +BUFFER_CASES: list = [] +for x in _windows: + for s in (1, 3, 4): + for d in (-1, 0, 1): + BUFFER_CASES.extend([x, s, d, dev] for dev in itertools.product(*_devices * 3)) + class TestSlidingWindowInference(unittest.TestCase): + @parameterized.expand(BUFFER_CASES) + def test_buffers(self, size_params, buffer_steps, buffer_dim, device_params): + def mult_two(patch, *args, **kwargs): + return 2.0 * patch + + img_size, roi_size, overlap, sw_batch_size = size_params + img_device, device, sw_device = device_params + dtype = [torch.float, torch.double][roi_size[0] % 2] # test different input dtype + mode = ["constant", "gaussian"][img_size[1] % 2] + image = torch.randint(0, 255, size=img_size, dtype=dtype, device=img_device) + sw = sliding_window_inference( + image, + roi_size, + sw_batch_size, + mult_two, + overlap, + mode=mode, + sw_device=sw_device, + device=device, + buffer_steps=buffer_steps, + buffer_dim=buffer_dim, + ) + max_diff = torch.max(torch.abs(image.to(sw) - 0.5 * sw)).item() + self.assertGreater(0.001, max_diff) + @parameterized.expand(TEST_CASES) def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device): n_total = np.prod(image_shape) @@ -244,6 +292,8 @@ def compute(data, test1, test2): has_tqdm, None, None, + None, + 0, t1, test2=t2, )
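Likewise, a hedged sketch of the class-based entry point, using the ``buffer_steps``/``buffer_dim`` keywords now forwarded by ``SlidingWindowInferer``; the small ``torch.nn.Conv2d`` network and input sizes are assumptions for illustration only.

import torch
from monai.inferers import SlidingWindowInferer

inferer = SlidingWindowInferer(
    roi_size=(64, 64),
    sw_batch_size=8,
    overlap=0.25,
    mode="gaussian",
    buffer_steps=4,  # new: number of window iterations buffered before writing to ``device``
    buffer_dim=0,    # new: buffer along the first spatial dimension
)
net = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)  # toy 2-channel "segmentation" head
with torch.no_grad():
    pred = inferer(torch.rand(2, 1, 200, 180), net)
print(pred.shape)  # torch.Size([2, 2, 200, 180])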