From 080f57cb99f4fa090d377ad87b228a9c6d046a58 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 27 Mar 2023 20:38:54 +0100 Subject: [PATCH 01/30] staging Signed-off-by: Wenqi Li --- monai/data/utils.py | 12 +- monai/inferers/inferer.py | 3 +- monai/inferers/utils.py | 252 +++++++++++++++++++++----------------- 3 files changed, 153 insertions(+), 114 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 2c035afb3f..872cca1da9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -164,7 +164,7 @@ def iter_patch_slices( def dense_patch_slices( - image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int] + image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int], return_slice: bool = True ) -> list[tuple[slice, ...]]: """ Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image. @@ -173,6 +173,7 @@ def dense_patch_slices( image_size: dimensions of image to iterate over patch_size: size of patches to generate slices scan_interval: dense patch sampling interval + return_slice: whether to return a list of slices (or tuples of indices), defaults to True Returns: a list of slice objects defining each patch @@ -200,7 +201,9 @@ def dense_patch_slices( dim_starts.append(start_idx) starts.append(dim_starts) out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T - return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + if return_slice: + return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + return [tuple((s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] def iter_patch_position( @@ -1058,6 +1061,7 @@ def compute_importance_map( mode: BlendMode | str = BlendMode.CONSTANT, sigma_scale: Sequence[float] | float = 0.125, device: torch.device | int | str = "cpu", + dtype: torch.dtype | str | None = torch.float32, ) -> torch.Tensor: """Get importance map for different weight modes. @@ -1072,6 +1076,7 @@ def compute_importance_map( sigma_scale: Sigma_scale to calculate sigma for each dimension (sigma = sigma_scale * dim_size). Used for gaussian mode only. device: Device to put importance map on. + dtype: Data type of the output importance map. Raises: ValueError: When ``mode`` is not one of ["constant", "gaussian"]. @@ -1098,6 +1103,9 @@ def compute_importance_map( raise ValueError( f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}]." ) + # handle non-positive weights + min_non_zero = max(torch.min(importance_map).item(), 1e-3) + importance_map = torch.clamp_(importance_map.to(torch.float), min=min_non_zero).to(dtype) return importance_map diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 03b5e7a75f..012a238c89 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -415,7 +415,8 @@ def __init__( warnings.warn("cache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)") except BaseException as e: raise RuntimeError( - "Seems to be OOM. Please try smaller roi_size, or use mode='constant' instead of mode='gaussian'. " + f"roi size {self.roi_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" + "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." 
) from e def __call__( diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c4405911d0..b23da92820 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -13,8 +13,10 @@ import warnings from collections.abc import Callable, Mapping, Sequence +from itertools import chain from typing import Any +import numpy as np import torch import torch.nn.functional as F @@ -116,14 +118,18 @@ def sliding_window_inference( args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. + - buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + Note: - input must be channel-first and have a batch dim, supports N-D sliding window. """ + b_steps = kwargs.pop("buffer_steps", None) + b_plane = kwargs.pop("buffer_plane", 0) compute_dtype = inputs.dtype num_spatial_dims = len(inputs.shape) - 2 if overlap < 0 or overlap >= 1: - raise ValueError("overlap must be >= 0 and < 1.") + raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") # determine image spatial size and batch size # Note: all input images must have the same image size and batch size @@ -134,7 +140,12 @@ def sliding_window_inference( if sw_device is None: sw_device = inputs.device + metadict = None + if isinstance(inputs, MetaTensor): + metadict = inputs.meta.copy() + inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0] roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) pad_size = [] @@ -142,16 +153,29 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - - if max(pad_size) > 0: + if any(pad_size): inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) + # Store all slices scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=False) + slices_np = np.asarray(slices) + slices_np = slices_np[np.argsort(slices_np[:, b_plane, 0], kind="mergesort")] + slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] + _, p_id, buffer_lens = np.unique(slices_np[:, b_plane, 0], return_counts=True, return_index=True) + b_se = [tuple(slices_np[i][b_plane]) for i in p_id] # buffer start & end along the b_plane + buffer_lens = np.repeat(buffer_lens, batch_size) + b_ends = np.cumsum(buffer_lens) - # Store all slices in list - slices = dense_patch_slices(image_size, roi_size, scan_interval) num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows + windows_range = range(0, total_slices, sw_batch_size) + if b_steps is not None: + windows_range, s = [], 0 + for b in buffer_lens: + windows_range.append(range(s, s + b, sw_batch_size)) + s = s + b + windows_range = chain(*windows_range) # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -159,148 +183,133 @@ def sliding_window_inference( importance_map_ = roi_weight_map else: try: + valid_p_size = ensure_tuple(valid_patch_size) importance_map_ = compute_importance_map( - valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device + valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) except BaseException as e: raise RuntimeError( + f"patch size 
{valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." ) from e - importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0] - - # handle non-positive weights - min_non_zero = max(torch.min(importance_map_).item(), 1e-3) - importance_map_ = torch.clamp_(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype) - - # Perform predictions - dict_key, output_image_list, count_map_list = None, [], [] - _initialized_ss = -1 - is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple) + importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] + # stores output and count map + output_image_list, count_map_list, sw_device_buffer, _initialized_ss, b_s, b_i = [], [], [], -1, 0, 0 # for each patch - for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size): - slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + for slice_g in tqdm(windows_range) if progress else windows_range: + slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices if b_steps is None else b_ends[b_s])) unravel_slice = [ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] - window_data = torch.cat( - [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice] - ).to(sw_device) - seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation - - # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. - seg_prob_tuple: tuple[torch.Tensor, ...] - if isinstance(seg_prob_out, torch.Tensor): - seg_prob_tuple = (seg_prob_out,) - elif isinstance(seg_prob_out, Mapping): - if dict_key is None: - dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys - seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key) - is_tensor_output = False - else: - seg_prob_tuple = ensure_tuple(seg_prob_out) - is_tensor_output = False + win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch + # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. + dict_keys, seg_tuple = _flatten_struct(seg_prob_out) if process_fn: - seg_prob_tuple, importance_map = process_fn(seg_prob_tuple, window_data, importance_map_) + seg_tuple, importance_map = process_fn(seg_tuple, win_data, importance_map_) else: importance_map = importance_map_ - # for each output in multi-output list - for ss, seg_prob in enumerate(seg_prob_tuple): - seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN - - # compute zoom scale: out_roi_size/in_roi_size - zoom_scale = [] - for axis, (img_s_i, out_w_i, in_w_i) in enumerate( - zip(image_size, seg_prob.shape[2:], window_data.shape[2:]) - ): - _scale = out_w_i / float(in_w_i) - if not (img_s_i * _scale).is_integer(): - warnings.warn( - f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial " - f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs." 
- ) - zoom_scale.append(_scale) - + if b_steps is not None: + if len(seg_tuple) > 1: + warnings.warn("Multiple outputs are not supported with buffer_steps, only the first output is used.") + if not sw_device_buffer: + k = seg_tuple[0].shape[1] + sp_size = list(image_size) + sp_size[b_plane] = roi_size[b_plane] # one step roi + sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] + importance_map = importance_map.to(dtype=compute_dtype, device=sw_device) + b_i = 0 + for p, s in zip(seg_tuple[0], unravel_slice): + offset = s[b_plane + 2].start - b_se[b_s % len(b_se)][0] + s[b_plane + 2] = slice(offset, offset + roi_size[b_plane]) + s[0] = slice(0, 1) + sw_device_buffer[0][s] += p * importance_map + b_i += 1 + if b_i < buffer_lens[b_s]: + continue + else: + sw_device_buffer = list(seg_tuple) + + for ss in range(len(sw_device_buffer)): + b_shape = sw_device_buffer[ss].shape + seg_chns, seg_shape = b_shape[1], b_shape[2:] + if b_steps is None and seg_shape != roi_size: + z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] + else: + z_scale = None + if seg_shape != importance_map.shape and b_steps is None: # resizing the importance_map + resizer = Resize(spatial_size=seg_shape, mode="nearest", anti_aliasing=False) + w_t = resizer(importance_map.unsqueeze(0))[None].to(dtype=compute_dtype, device=sw_device) + else: + w_t = importance_map[None, None].to(dtype=compute_dtype, device=sw_device) if _initialized_ss < ss: # init. the ss-th buffer at the first iteration # construct multi-resolution outputs - output_classes = seg_prob.shape[1] - output_shape = [batch_size, output_classes] + [ - int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale) - ] + output_shape = [batch_size, seg_chns] + output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) + w_t = w_t.to(device) + for __s in slices: + if z_scale is not None: + __s = [slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)] + count_map_list[-1][(slice(None), slice(None), *__s)] += w_t _initialized_ss += 1 - - # resizing the importance_map - resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False) - - # store the result in the proper location of the full output. Apply weights from importance map. - for idx, original_idx in zip(slice_range, unravel_slice): - # zoom roi - original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image - for axis in range(2, len(original_idx_zoom)): - zoomed_start = original_idx[axis].start * zoom_scale[axis - 2] - zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2] - if not zoomed_start.is_integer() or (not zoomed_end.is_integer()): - warnings.warn( - f"For axis-{axis-2} of output[{ss}], the output roi range is not int. " - f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). " - f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. " - f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n" - f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. " - "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works." 
- ) - original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None) - importance_map_zoom = ( - resizer(importance_map.unsqueeze(0))[0].to(compute_dtype) - if seg_prob.shape[2:] != importance_map.shape - else importance_map.to(compute_dtype) - ) - # store results and weights - output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g] - count_map_list[ss][original_idx_zoom] += ( - importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape) - ) + w_t = w_t.to(sw_device) + if b_steps is not None: + o_slice = [slice(None)] * len(inputs.shape) + o_slice[b_plane + 2] = slice(b_se[b_s % len(b_se)][0], b_se[b_s % len(b_se)][1]) + img_b = int(b_s / len(b_se)) # image batch index + o_slice[0] = slice(img_b, img_b + 1) + output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + continue + sw_t = sw_device_buffer[ss] + sw_t *= w_t[0, 0] + sw_t = sw_t.to(device) + _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) + b_s += 1 + sw_device_buffer, b_i = None, 0 # account for any overlapping sections for ss in range(len(output_image_list)): - output_image_list[ss] = output_image_list[ss] - _map = count_map_list.pop(0) - for _i in range(output_image_list[ss].shape[1]): - output_image_list[ss][:, _i : _i + 1, ...] /= _map - output_image_list[ss] = output_image_list[ss].to(compute_dtype) + output_image_list[ss] /= count_map_list.pop(0) # remove padding if image_size smaller than roi_size for ss, output_i in enumerate(output_image_list): - zoom_scale = [ - seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size) - ] - + zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] final_slicing: list[slice] = [] for sp in range(num_spatial_dims): - slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) + si = num_spatial_dims - sp - 1 slice_dim = slice( - int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])), - int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])), + int(round(pad_size[sp * 2] * zoom_scale[si])), + int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])), ) final_slicing.insert(0, slice_dim) while len(final_slicing) < len(output_i.shape): final_slicing.insert(0, slice(None)) output_image_list[ss] = output_i[final_slicing] - if dict_key is not None: # if output of predictor is a dict - final_output = dict(zip(dict_key, output_image_list)) - else: - final_output = tuple(output_image_list) # type: ignore - final_output = final_output[0] if is_tensor_output else final_output - - if isinstance(inputs, MetaTensor): - final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore - return final_output + final_output = _pack_struct(output_image_list, dict_keys) + final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore + if metadict is not None: + final_output = MetaTensor(final_output, meta=metadict) + return final_output # type: ignore + + +def _compute_coords(sw, coords, z_scale, out, patch): + """sliding window batch spatial scaling indexing for multi-resolution outputs.""" + for original_idx, p in zip(coords, patch): + idx_zm = list(original_idx) # 4D for 2D image, 5D for 3D image + if z_scale: + for axis in range(2, len(idx_zm)): + idx_zm[axis] = slice( + int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * 
z_scale[axis - 2]) + ) + out[idx_zm] += p def _get_scan_interval( @@ -313,9 +322,9 @@ def _get_scan_interval( """ if len(image_size) != num_spatial_dims: - raise ValueError("image coord different from spatial dims.") + raise ValueError(f"len(image_size) {len(image_size)} different from spatial dims {num_spatial_dims}.") if len(roi_size) != num_spatial_dims: - raise ValueError("roi coord different from spatial dims.") + raise ValueError(f"len(roi_size) {len(roi_size)} different from spatial dims {num_spatial_dims}.") scan_interval = [] for i in range(num_spatial_dims): @@ -325,3 +334,24 @@ def _get_scan_interval( interval = int(roi_size[i] * (1 - overlap)) scan_interval.append(interval if interval > 0 else 1) return tuple(scan_interval) + + +def _flatten_struct(seg_out): + dict_keys = None + seg_probs: tuple[torch.Tensor, ...] + if isinstance(seg_out, torch.Tensor): + seg_probs = (seg_out,) + elif isinstance(seg_out, Mapping): + dict_keys = sorted(seg_out.keys()) # track predictor's output keys + seg_probs = tuple(seg_out[k] for k in dict_keys) + else: + seg_probs = ensure_tuple(seg_out) # type: ignore + return dict_keys, seg_probs + + +def _pack_struct(seg_out, dict_keys=None): + if dict_keys is not None: + return dict(zip(dict_keys, seg_out)) + if isinstance(seg_out, (list, tuple)) and len(seg_out) == 1: + return seg_out[0] + return ensure_tuple(seg_out) From 60037e05ea7ed436926eb857759aa35be00de998 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Mar 2023 16:50:13 -0400 Subject: [PATCH 02/30] opt Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index b23da92820..c8e879f305 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -187,7 +187,7 @@ def sliding_window_inference( importance_map_ = compute_importance_map( valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) - except BaseException as e: + except Exception as e: raise RuntimeError( f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." @@ -195,7 +195,7 @@ def sliding_window_inference( importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] # stores output and count map - output_image_list, count_map_list, sw_device_buffer, _initialized_ss, b_s, b_i = [], [], [], -1, 0, 0 + output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices if b_steps is None else b_ends[b_s])) @@ -246,7 +246,7 @@ def sliding_window_inference( w_t = resizer(importance_map.unsqueeze(0))[None].to(dtype=compute_dtype, device=sw_device) else: w_t = importance_map[None, None].to(dtype=compute_dtype, device=sw_device) - if _initialized_ss < ss: # init. the ss-th buffer at the first iteration + if len(output_image_list) <= ss: # init. 
the ss-th buffer at the first iteration # construct multi-resolution outputs output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) w_t = w_t.to(device) for __s in slices: if z_scale is not None: __s = [slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)] count_map_list[-1][(slice(None), slice(None), *__s)] += w_t - _initialized_ss += 1 w_t = w_t.to(sw_device) if b_steps is not None: o_slice = [slice(None)] * len(inputs.shape) From 1e144eb921434bfb67a9335875d9386f1793c55e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Mar 2023 20:25:13 -0400 Subject: [PATCH 03/30] valid Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 108 +++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 50 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c8e879f305..f616ff1584 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -11,9 +11,8 @@ from __future__ import annotations -import warnings from collections.abc import Callable, Mapping, Sequence -from itertools import chain +import itertools from typing import Any import numpy as np @@ -119,26 +118,31 @@ def sliding_window_inference( kwargs: optional keyword args to be passed to ``predictor``. - buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + default is None, no buffer. + - buffer_dim: the dimension along which the buffers are created, default is 0. Note: - input must be channel-first and have a batch dim, supports N-D sliding window. """ b_steps = kwargs.pop("buffer_steps", None) - b_plane = kwargs.pop("buffer_plane", 0) + b_plane = kwargs.pop("buffer_dim", 0) + buffered = b_steps is not None and b_steps > 0 num_spatial_dims = len(inputs.shape) - 2 + if buffered: + if (b_plane < -num_spatial_dims + 1 or b_plane > num_spatial_dims): + raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {b_plane}.") + if b_steps <= 0: + raise ValueError(f"buffer_steps must be > 0, got {b_steps}.") if overlap < 0 or overlap >= 1: raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") compute_dtype = inputs.dtype # determine image spatial size and batch size # Note: all input images must have the same image size and batch size batch_size, _, *image_size_ = inputs.shape - - if device is None: - device = inputs.device - if sw_device is None: - sw_device = inputs.device + device = device or inputs.device + sw_device = sw_device or inputs.device metadict = None if isinstance(inputs, MetaTensor): @@ -159,23 +163,26 @@ def sliding_window_inference( # Store all slices scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=False) slices_np = np.asarray(slices) + if b_plane < 0: + b_plane += num_spatial_dims slices_np = slices_np[np.argsort(slices_np[:, b_plane, 0], kind="mergesort")] slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] _, _p_id, _b_lens = np.unique(slices_np[:, b_plane, 0], return_counts=True, return_index=True) b_se = 
[tuple(slices_np[i][b_plane]) for i in _p_id] # buffer start & end along the b_plane + b_ends = np.cumsum(np.repeat(_b_lens, batch_size)) # buffer flush boundaries num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows - windows_range = range(0, total_slices, sw_batch_size) - if b_steps is not None: - windows_range, s = [], 0 - for b in buffer_lens: - windows_range.append(range(s, s + b, sw_batch_size)) - s = s + b - windows_range = chain(*windows_range) + if not buffered: + windows_range = range(0, total_slices, sw_batch_size) + else: + b_steps = min(len(b_se), b_steps) + x = [0, *b_ends][::b_steps] + if x[-1] < b_ends[-1]: + x.append(b_ends[-1]) + windows_range = itertools.chain(*[range(x[i], x[i+1], sw_batch_size) for i in range(len(x) - 1)]) # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -198,9 +205,10 @@ def sliding_window_inference( output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: - slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices if b_steps is None else b_ends[b_s])) + _cur_max = b_ends[b_s + b_steps - 1] if buffered else total_slices + slice_range = range(slice_g, min(slice_g + sw_batch_size, _cur_max)) unravel_slice = [ - [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) + [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) @@ -213,41 +221,40 @@ def sliding_window_inference( else: importance_map = importance_map_ - if b_steps is not None: - if len(seg_tuple) > 1: - warnings.warn("Multiple outputs are not supported with buffer_steps, only the first output is used.") + if buffered: + # if len(seg_tuple) > 1: + # warnings.warn("Multiple outputs are not supported with buffer_steps") + c_start, c_end = b_se[b_s % len(b_se)], b_se[(b_s + b_steps - 1) % len(b_se)] if not sw_device_buffer: k = seg_tuple[0].shape[1] sp_size = list(image_size) - sp_size[b_plane] = roi_size[b_plane] # one step roi + sp_size[b_plane] = max(c_end[1] - c_start[0], roi_size[b_plane]) sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] importance_map = importance_map.to(dtype=compute_dtype, device=sw_device) - b_i = 0 for p, s in zip(seg_tuple[0], unravel_slice): - offset = s[b_plane + 2].start - b_se[b_s % len(b_se)][0] + offset = s[b_plane + 2].start - c_start[0] s[b_plane + 2] = slice(offset, offset + roi_size[b_plane]) s[0] = slice(0, 1) sw_device_buffer[0][s] += p * importance_map - b_i += 1 - if b_i < buffer_lens[b_s]: + b_i += len(unravel_slice) + if b_i < b_ends[b_s + b_steps - 1]: continue else: - sw_device_buffer = list(seg_tuple) + sw_device_buffer = seg_tuple for ss in range(len(sw_device_buffer)): b_shape = sw_device_buffer[ss].shape seg_chns, seg_shape = b_shape[1], b_shape[2:] - if b_steps is None and seg_shape != roi_size: - z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] - else: + if buffered or seg_shape == roi_size: z_scale = None - if seg_shape != importance_map.shape and b_steps is None: # resizing the importance_map - resizer = Resize(spatial_size=seg_shape, mode="nearest", anti_aliasing=False) - w_t = resizer(importance_map.unsqueeze(0))[None].to(dtype=compute_dtype, 
device=sw_device) else: - w_t = importance_map[None, None].to(dtype=compute_dtype, device=sw_device) - if len(output_image_list) <= ss: # init. the ss-th buffer at the first iteration - # construct multi-resolution outputs + z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] + if buffered or seg_shape == importance_map.shape: + w_t = importance_map.to(dtype=compute_dtype, device=sw_device) + else: # resizing the importance_map + resizer = Resize(spatial_size=seg_shape, mode="nearest", anti_aliasing=False) + w_t = resizer(importance_map.unsqueeze(0))[0].to(dtype=compute_dtype, device=sw_device) + if len(output_image_list) <= ss: output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts @@ -259,19 +266,20 @@ def sliding_window_inference( __s = [slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)] count_map_list[-1][(slice(None), slice(None), *__s)] += w_t w_t = w_t.to(sw_device) - if b_steps is not None: + if buffered: o_slice = [slice(None)] * len(inputs.shape) - o_slice[b_plane + 2] = slice(b_se[b_s % len(b_se)][0], b_se[b_s % len(b_se)][1]) - img_b = int(b_s / len(b_se)) # image batch index + o_slice[b_plane + 2] = slice(c_start[0], c_end[1]) + img_b = b_s // len(b_se) # image batch index o_slice[0] = slice(img_b, img_b + 1) output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) - continue - sw_t = sw_device_buffer[ss] - sw_t *= w_t[0, 0] - sw_t = sw_t.to(device) - _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) - b_s += 1 - sw_device_buffer, b_i = None, 0 + else: + sw_t = sw_device_buffer[ss] + sw_t *= w_t + sw_t = sw_t.to(device) + _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) + sw_device_buffer = None + if buffered: + b_s += b_steps # account for any overlapping sections for ss in range(len(output_image_list)): From 0c11323e04d6aaa162a68306177c0fcf3af42786 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 02:30:51 +0100 Subject: [PATCH 04/30] update Signed-off-by: Wenqi Li --- monai/data/utils.py | 2 +- monai/inferers/utils.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 872cca1da9..7f09ed3921 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -203,7 +203,7 @@ def dense_patch_slices( out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T if return_slice: return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] - return [tuple((s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] + return [tuple((s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] # type: ignore def iter_patch_position( diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index f616ff1584..1edcd5d0fb 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -11,9 +11,9 @@ from __future__ import annotations -from collections.abc import Callable, Mapping, Sequence import itertools -from typing import Any +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Iterable import numpy as np import torch @@ -130,7 +130,7 @@ def sliding_window_inference( buffered = b_steps is not None and b_steps > 0 num_spatial_dims = len(inputs.shape) - 2 if buffered: - if (b_plane < -num_spatial_dims + 1 or 
b_plane > num_spatial_dims): + if b_plane < -num_spatial_dims + 1 or b_plane > num_spatial_dims: raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {b_plane}.") if b_steps <= 0: raise ValueError(f"buffer_steps must be >= 0, got {b_steps}.") @@ -175,6 +175,7 @@ def sliding_window_inference( num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows + windows_range: Iterable if not buffered: windows_range = range(0, total_slices, sw_batch_size) else: @@ -182,7 +183,7 @@ def sliding_window_inference( x = [0, *b_ends][::b_steps] if x[-1] < b_ends[-1]: x.append(b_ends[-1]) - windows_range = itertools.chain(*[range(x[i], x[i+1], sw_batch_size) for i in range(len(x) - 1)]) + windows_range = itertools.chain(*[range(x[i], x[i + 1], sw_batch_size) for i in range(len(x) - 1)]) # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -202,7 +203,7 @@ def sliding_window_inference( importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] # stores output and count map - output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 + output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: _cur_max = b_ends[b_s + b_steps - 1] if buffered else total_slices @@ -263,7 +264,7 @@ def sliding_window_inference( w_t = w_t.to(device) for __s in slices: if z_scale is not None: - __s = [slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)] + __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)) count_map_list[-1][(slice(None), slice(None), *__s)] += w_t w_t = w_t.to(sw_device) if buffered: @@ -277,7 +278,7 @@ def sliding_window_inference( sw_t *= w_t sw_t = sw_t.to(device) _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) - sw_device_buffer = None + sw_device_buffer = [] if buffered: b_s += b_steps From 011854fecc885ae5877bc5473c7373dde80a6a61 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 06:18:25 -0400 Subject: [PATCH 05/30] fixes copying Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 1edcd5d0fb..85e3b7a991 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -144,9 +144,9 @@ def sliding_window_inference( device = device or inputs.device sw_device = sw_device or inputs.device - metadict = None + temp_meta = None if isinstance(inputs, MetaTensor): - metadict = inputs.meta.copy() + temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False) inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0] roi_size = fall_back_tuple(roi_size, image_size_) @@ -303,8 +303,8 @@ def sliding_window_inference( final_output = _pack_struct(output_image_list, dict_keys) final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore - if metadict is not None: - final_output = MetaTensor(final_output, meta=metadict) + if temp_meta is not None: + final_output = MetaTensor(final_output).copy_meta_from(temp_meta) return final_output # type: ignore From 9bb2b4da5618586ba92f238b367fa2e5ecab0509 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 11:36:03 -0400 Subject: [PATCH 06/30] 
more tests Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 85e3b7a991..4c752c8b07 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -171,7 +171,7 @@ def sliding_window_inference( slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] _, _p_id, _b_lens = np.unique(slices_np[:, b_plane, 0], return_counts=True, return_index=True) b_se = [tuple(slices_np[i][b_plane]) for i in _p_id] # buffer start & end along the b_plane - b_ends = np.cumsum(np.repeat(_b_lens, batch_size)) # buffer flush boundaries + b_ends = np.cumsum(_b_lens) # possible buffer flush boundaries num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows @@ -183,7 +183,14 @@ def sliding_window_inference( x = [0, *b_ends][::b_steps] if x[-1] < b_ends[-1]: x.append(b_ends[-1]) - windows_range = itertools.chain(*[range(x[i], x[i + 1], sw_batch_size) for i in range(len(x) - 1)]) + n_per_batch = len(x) - 1 + windows_range, b_ends = [], [0] + for b in range(batch_size): + offset = b * x[-1] + for i in range(n_per_batch): + windows_range.append(range(offset + x[i], offset + x[i + 1], sw_batch_size)) + b_ends.append(offset + x[i + 1]) + windows_range = itertools.chain(*windows_range) # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -206,7 +213,7 @@ def sliding_window_inference( output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: - _cur_max = b_ends[b_s + b_steps - 1] if buffered else total_slices + _cur_max = b_ends[b_s + 1] if buffered else total_slices slice_range = range(slice_g, min(slice_g + sw_batch_size, _cur_max)) unravel_slice = [ [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) @@ -223,22 +230,21 @@ def sliding_window_inference( importance_map = importance_map_ if buffered: - c_start, c_end = b_se[b_s % len(b_se)], b_se[(b_s + b_steps - 1) % len(b_se)] + c_start = slices_np[b_ends[b_s] % num_win, b_plane, 0] + c_end = slices_np[(b_ends[b_s + 1] - 1) % num_win, b_plane, 1] if not sw_device_buffer: - k = seg_tuple[0].shape[1] + k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored sp_size = list(image_size) - sp_size[b_plane] = max(c_end[1] - c_start[0], roi_size[b_plane]) + sp_size[b_plane] = c_end - c_start sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] importance_map = importance_map.to(dtype=compute_dtype, device=sw_device) for p, s in zip(seg_tuple[0], unravel_slice): - offset = s[b_plane + 2].start - c_start[0] + offset = s[b_plane + 2].start - c_start s[b_plane + 2] = slice(offset, offset + roi_size[b_plane]) s[0] = slice(0, 1) sw_device_buffer[0][s] += p * importance_map b_i += len(unravel_slice) - if b_i < b_ends[b_s + b_steps - 1]: + if b_i < b_ends[b_s + 1]: continue else: sw_device_buffer = seg_tuple @@ -269,8 +275,8 @@ def sliding_window_inference( w_t = w_t.to(sw_device) if buffered: o_slice = [slice(None)] * len(inputs.shape) - o_slice[b_plane + 2] = slice(c_start[0], c_end[1]) + o_slice[b_plane + 2] = slice(c_start, c_end) - img_b = b_s // len(b_se) # image batch index + img_b = b_s // 
n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) else: sw_t = sw_device_buffer[ss] sw_t *= w_t sw_t = sw_t.to(device) _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) sw_device_buffer = [] if buffered: - b_s += b_steps + b_s += 1 # account for any overlapping sections for ss in range(len(output_image_list)): From 857d28638fe33760cbcc7b2bff1d4e053a3ac7da Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 17:05:18 +0100 Subject: [PATCH 07/30] update Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 55 ++++++++++++-------------- tests/test_sliding_window_inference.py | 2 + 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 4c752c8b07..019074d652 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -53,6 +53,8 @@ def sliding_window_inference( progress: bool = False, roi_weight_map: torch.Tensor | None = None, process_fn: Callable | None = None, + buffer_steps: int | None = None, + buffer_dim: int = 0, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -114,26 +116,23 @@ def sliding_window_inference( roi_weight_map: pre-computed (non-negative) weight map for each ROI. If not given, and ``mode`` is not `constant`, this map will be computed on the fly. process_fn: process inference output and adjust the importance map per window + buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + default is None, no buffer. + buffer_dim: the dimension along which the buffers are created, default is 0. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. - - buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. - default is None, no buffer. - - buffer_dim: the dimension along which the buffers are created, default is 0. Note: - input must be channel-first and have a batch dim, supports N-D sliding window. 
""" - b_steps = kwargs.pop("buffer_steps", None) - b_plane = kwargs.pop("buffer_dim", 0) - buffered = b_steps is not None and b_steps > 0 + buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 if buffered: - if b_plane < -num_spatial_dims + 1 or b_plane > num_spatial_dims: - raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {b_plane}.") - if b_steps <= 0: - raise ValueError(f"buffer_steps must be >= 0, got {b_steps}.") + if buffer_dim < -num_spatial_dims + 1 or buffer_dim > num_spatial_dims: + raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {buffer_dim}.") + if buffer_steps <= 0: # type: ignore + raise ValueError(f"buffer_steps must be >= 0, got {buffer_steps}.") if overlap < 0 or overlap >= 1: raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") compute_dtype = inputs.dtype @@ -165,13 +164,13 @@ def sliding_window_inference( slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=False) slices_np = np.asarray(slices) - if b_plane < 0: - b_plane += num_spatial_dims - slices_np = slices_np[np.argsort(slices_np[:, b_plane, 0], kind="mergesort")] + if buffer_dim < 0: + buffer_dim += num_spatial_dims + slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")] slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] - _, _p_id, _b_lens = np.unique(slices_np[:, b_plane, 0], return_counts=True, return_index=True) - b_se = [tuple(slices_np[i][b_plane]) for i in _p_id] # buffer start & end along the b_plane - b_ends = np.cumsum(_b_lens) # possbile buffer flush boundaries + _, _p_id, _b_lens = np.unique(slices_np[:, buffer_dim, 0], return_counts=True, return_index=True) + _b_se = [tuple(slices_np[i][buffer_dim]) for i in _p_id] # buffer start & end along the buffer_dim + b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows @@ -179,12 +178,11 @@ def sliding_window_inference( if not buffered: windows_range = range(0, total_slices, sw_batch_size) else: - b_steps = min(len(b_se), b_steps) - x = [0, *b_ends][::b_steps] + buffer_steps = min(len(_b_se), int(buffer_steps)) # type: ignore + x = [0, *b_ends][::buffer_steps] if x[-1] < b_ends[-1]: x.append(b_ends[-1]) - n_per_batch = len(x) - 1 - windows_range, b_ends = [], [0] + windows_range, n_per_batch, b_ends = [], len(x) - 1, [0] for b in range(batch_size): offset = b * x[-1] for i in range(n_per_batch): @@ -213,8 +211,7 @@ def sliding_window_inference( output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: - _cur_max = b_ends[b_s + 1] if buffered else total_slices - slice_range = range(slice_g, min(slice_g + sw_batch_size, _cur_max)) + slice_range = range(slice_g, min(slice_g + sw_batch_size, b_ends[b_s + 1] if buffered else total_slices)) unravel_slice = [ [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range @@ -230,17 +227,17 @@ def sliding_window_inference( importance_map = importance_map_ if buffered: - c_start = slices_np[b_ends[b_s] % num_win, b_plane, 0] - c_end = slices_np[(b_ends[b_s + 1] - 1) % num_win, b_plane, 1] + c_start = slices_np[b_ends[b_s] % num_win, buffer_dim, 0] + c_end = slices_np[(b_ends[b_s + 1] - 1) % num_win, buffer_dim, 1] 
if not sw_device_buffer: k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored sp_size = list(image_size) - sp_size[b_plane] = c_end - c_start + sp_size[buffer_dim] = c_end - c_start sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] importance_map = importance_map.to(dtype=compute_dtype, device=sw_device) for p, s in zip(seg_tuple[0], unravel_slice): - offset = s[b_plane + 2].start - c_start - s[b_plane + 2] = slice(offset, offset + roi_size[b_plane]) + offset = s[buffer_dim + 2].start - c_start + s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) sw_device_buffer[0][s] += p * importance_map b_i += len(unravel_slice) if b_i < b_ends[b_s + 1]: @@ -275,7 +272,7 @@ def sliding_window_inference( w_t = w_t.to(sw_device) if buffered: o_slice = [slice(None)] * len(inputs.shape) - o_slice[b_plane + 2] = slice(c_start, c_end) + o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 5f07084927..36a62d653f 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -244,6 +244,8 @@ def compute(data, test1, test2): has_tqdm, None, None, + None, + 0, t1, test2=t2, ) From bbb91fe068682f965d417fb12a9cd5ed05c3280d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 17:14:14 +0100 Subject: [PATCH 08/30] update api Signed-off-by: Wenqi Li --- monai/inferers/inferer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 012a238c89..f56c64a43f 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -366,6 +366,9 @@ class SlidingWindowInferer(Inferer): cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory) when input image volume is larger than this threshold (in pixels/voxels). Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu. + buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + default is None, no buffer. + buffer_dim: the dimension along which the buffers are created, default is 0. Note: ``sw_batch_size`` denotes the max number of windows per network inference iteration, @@ -387,6 +390,8 @@ def __init__( progress: bool = False, cache_roi_weight_map: bool = False, cpu_thresh: int | None = None, + buffer_steps: int | None = None, + buffer_dim: int = 0, ) -> None: super().__init__() self.roi_size = roi_size @@ -400,6 +405,8 @@ def __init__( self.device = device self.progress = progress self.cpu_thresh = cpu_thresh + self.buffer_steps = buffer_steps + self.buffer_dim = buffer_dim # compute_importance_map takes long time when computing on cpu. 
We thus # compute it once if it's static and then save it for future usage @@ -441,6 +448,7 @@ def __call__( if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh: device = "cpu" # stitch in cpu memory if image is too large + return sliding_window_inference( inputs, self.roi_size, @@ -456,6 +464,8 @@ def __call__( self.progress, self.roi_weight_map, None, + self.buffer_steps, + self.buffer_dim, *args, **kwargs, ) From efd796a89fc4d616b7eaaa1b1944c2142d0ef2a9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 17:15:58 +0100 Subject: [PATCH 09/30] update Signed-off-by: Wenqi Li --- monai/inferers/inferer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index f56c64a43f..952872b5ba 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -448,7 +448,6 @@ def __call__( if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh: device = "cpu" # stitch in cpu memory if image is too large - return sliding_window_inference( inputs, self.roi_size, From 9a36eeebd54e5765b78ed4534d1c0e2747e9cf4e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Mar 2023 17:42:55 +0100 Subject: [PATCH 10/30] fixes test cases Signed-off-by: Wenqi Li --- monai/apps/pathology/inferers/inferer.py | 2 ++ tests/test_sliding_window_hovernet_inference.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py index da4ac4fd7a..7a60c23aa2 100644 --- a/monai/apps/pathology/inferers/inferer.py +++ b/monai/apps/pathology/inferers/inferer.py @@ -176,6 +176,8 @@ def __call__( self.progress, self.roi_weight_map, self.process_output, + self.buffer_steps, + self.buffer_dim, *args, **kwargs, ) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 0dc2216c22..8f7f8346cc 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -232,6 +232,8 @@ def compute(data, test1, test2): has_tqdm, None, None, + None, + 0, t1, test2=t2, ) From e52effe866b155dfaa8244830d58f269e2c363a9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Apr 2023 05:54:23 -0400 Subject: [PATCH 11/30] optimize Signed-off-by: Wenqi Li --- .gitignore | 1 + monai/inferers/utils.py | 136 +++++++++++++++++++++------------------- monai/utils/module.py | 3 + 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/.gitignore b/.gitignore index 4e235c7774..2fd28bbfc7 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,4 @@ tests/testing_data/CT_2D_head_moving.mha # profiling results *.prof +runs diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 019074d652..8da2bdc237 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -21,7 +21,6 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size -from monai.transforms import Resize from monai.utils import ( BlendMode, PytorchPadMode, @@ -31,9 +30,11 @@ fall_back_tuple, look_up_option, optional_import, + pytorch_after, ) tqdm, _ = optional_import("tqdm", name="tqdm") +_nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest" __all__ = ["sliding_window_inference"] @@ -54,7 +55,7 @@ def sliding_window_inference( roi_weight_map: torch.Tensor | None = None, process_fn: Callable | None = None, buffer_steps: int | None = None, - 
buffer_dim: int = 0, + buffer_dim: int = -1, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -117,8 +118,9 @@ def sliding_window_inference( If not given, and ``mode`` is not `constant`, this map will be computed on the fly. process_fn: process inference output and adjust the importance map per window buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. - default is None, no buffer. - buffer_dim: the dimension along which the buffers are created, default is 0. + default is None, no buffering. + buffer_dim: the spatial dimension along which the buffers are created. + 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -129,10 +131,12 @@ def sliding_window_inference( buffered = buffer_steps is not None and buffer_steps > 0 num_spatial_dims = len(inputs.shape) - 2 if buffered: - if buffer_dim < -num_spatial_dims + 1 or buffer_dim > num_spatial_dims: - raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {buffer_dim}.") + if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: + raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") if buffer_steps <= 0: # type: ignore raise ValueError(f"buffer_steps must be > 0, got {buffer_steps}.") + if buffer_dim < 0: + buffer_dim += num_spatial_dims if overlap < 0 or overlap >= 1: raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") compute_dtype = inputs.dtype @@ -161,16 +165,7 @@ def sliding_window_inference( # Store all slices scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) - slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=False) - - slices_np = np.asarray(slices) - if buffer_dim < 0: - buffer_dim += num_spatial_dims - slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")] - slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] - _, _p_id, _b_lens = np.unique(slices_np[:, buffer_dim, 0], return_counts=True, return_index=True) - _b_se = [tuple(slices_np[i][buffer_dim]) for i in _p_id] # buffer start & end along the buffer_dim - b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries + slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered) num_win = len(slices) # number of windows per image total_slices = num_win * batch_size # total number of windows @@ -178,17 +173,9 @@ def sliding_window_inference( if not buffered: windows_range = range(0, total_slices, sw_batch_size) else: - buffer_steps = min(len(_b_se), int(buffer_steps)) # type: ignore - x = [0, *b_ends][::buffer_steps] - if x[-1] < b_ends[-1]: - x.append(b_ends[-1]) - windows_range, n_per_batch, b_ends = [], len(x) - 1, [0] - for b in range(batch_size): - offset = b * x[-1] - for i in range(n_per_batch): - windows_range.append(range(offset + x[i], offset + x[i + 1], sw_batch_size)) - b_ends.append(offset + x[i + 1]) - windows_range = itertools.chain(*windows_range) + slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( + slices, batch_size, sw_batch_size, buffer_dim, buffer_steps + ) # Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -211,7 +198,7 @@ def sliding_window_inference( importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] # stores output and count map output_image_list, count_map_list, 
sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore # for each patch for slice_g in tqdm(windows_range) if progress else windows_range: - slice_range = range(slice_g, min(slice_g + sw_batch_size, b_ends[b_s + 1] if buffered else total_slices)) + slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices)) unravel_slice = [ [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range @@ -222,54 +209,48 @@ def sliding_window_inference( # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. dict_keys, seg_tuple = _flatten_struct(seg_prob_out) if process_fn: - seg_tuple, importance_map = process_fn(seg_tuple, win_data, importance_map_) + seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_) else: - importance_map = importance_map_ - + w_t = importance_map_ + if len(w_t.shape) == num_spatial_dims: + w_t = w_t[None, None] + w_t = w_t.to(dtype=compute_dtype, device=sw_device) if buffered: - c_start = slices_np[b_ends[b_s] % num_win, buffer_dim, 0] - c_end = slices_np[(b_ends[b_s + 1] - 1) % num_win, buffer_dim, 1] + c_start, c_end = b_slices[b_s][1:] if not sw_device_buffer: k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored sp_size = list(image_size) sp_size[buffer_dim] = c_end - c_start sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] - importance_map = importance_map.to(dtype=compute_dtype, device=sw_device) for p, s in zip(seg_tuple[0], unravel_slice): offset = s[buffer_dim + 2].start - c_start s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) - sw_device_buffer[0][s] += p * importance_map + sw_device_buffer[0][s] += p * w_t b_i += len(unravel_slice) - if b_i < b_ends[b_s + 1]: + if b_i < b_slices[b_s][0]: continue else: - sw_device_buffer = seg_tuple + sw_device_buffer = list(seg_tuple) for ss in range(len(sw_device_buffer)): b_shape = sw_device_buffer[ss].shape seg_chns, seg_shape = b_shape[1], b_shape[2:] - if buffered or seg_shape == roi_size: - z_scale = None - else: + z_scale = None + if not buffered and seg_shape != roi_size: z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] - if buffered or seg_shape == importance_map.shape: - w_t = importance_map.to(dtype=compute_dtype, device=sw_device) - else: # resizing the importance_map - resizer = Resize(spatial_size=seg_shape, mode="nearest", anti_aliasing=False) - w_t = resizer(importance_map.unsqueeze(0))[0].to(dtype=compute_dtype, device=sw_device) + w_t = torch.nn.functional.interpolate(w_t, seg_shape, mode=_nearest_mode) if len(output_image_list) <= ss: output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) - w_t = w_t.to(device) + w_t_ = w_t.to(device) for __s in slices: if z_scale is not None: __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)) - count_map_list[-1][(slice(None), slice(None), *__s)] += w_t - w_t = w_t.to(sw_device) + count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_ if buffered: o_slice = [slice(None)] * len(inputs.shape) o_slice[buffer_dim + 2] = 
slice(c_start, c_end) @@ -277,10 +258,9 @@ def sliding_window_inference( o_slice[0] = slice(img_b, img_b + 1) output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) else: - sw_t = sw_device_buffer[ss] - sw_t *= w_t - sw_t = sw_t.to(device) - _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t) + sw_device_buffer[ss] *= w_t + sw_device_buffer[ss] = sw_device_buffer[ss].to(device) + _compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss]) sw_device_buffer = [] if buffered: b_s += 1 @@ -290,19 +270,18 @@ def sliding_window_inference( output_image_list[ss] /= count_map_list.pop(0) # remove padding if image_size smaller than roi_size - for ss, output_i in enumerate(output_image_list): - zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] - final_slicing: list[slice] = [] - for sp in range(num_spatial_dims): - si = num_spatial_dims - sp - 1 - slice_dim = slice( - int(round(pad_size[sp * 2] * zoom_scale[si])), - int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])), - ) - final_slicing.insert(0, slice_dim) - while len(final_slicing) < len(output_i.shape): - final_slicing.insert(0, slice(None)) - output_image_list[ss] = output_i[final_slicing] + if any(pad_size): + for ss, output_i in enumerate(output_image_list): + zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] + final_slicing: list[slice] = [] + for sp in range(num_spatial_dims): + si = num_spatial_dims - sp - 1 + slice_dim = slice( + int(round(pad_size[sp * 2] * zoom_scale[si])), + int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])), + ) + final_slicing.insert(0, slice_dim) + output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)] final_output = _pack_struct(output_image_list, dict_keys) final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore @@ -311,6 +290,33 @@ def sliding_window_inference( return final_output # type: ignore +def _create_buffered_slices(slices, batch_size, sw_batch_size, buffer_dim, buffer_steps): + """rearrange slices for buffering""" + slices_np = np.asarray(slices) + slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")] + slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] + slices_np = slices_np[:, buffer_dim] + + _, _, _b_lens = np.unique(slices_np[:, 0], return_counts=True, return_index=True) + b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries + x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))] # type: ignore + if x[-1] < b_ends[-1]: + x.append(b_ends[-1]) + n_per_batch = len(x) - 1 + windows_range = [ + range(b * x[-1] + x[i], b * x[-1] + x[i + 1], sw_batch_size) + for b in range(batch_size) + for i in range(n_per_batch) + ] + b_slices = [] + for _s, _r in enumerate(windows_range): + s_s = slices_np[windows_range[_s - 1].stop % len(slices) if _s > 0 else 0, 0] + s_e = slices_np[(_r.stop - 1) % len(slices), 1] + b_slices.append((_r.stop, s_s, s_e)) # buffer index, slice start, slice end + windows_range = itertools.chain(*windows_range) # type: ignore + return slices, n_per_batch, b_slices, windows_range + + def _compute_coords(sw, coords, z_scale, out, patch): """sliding window batch spatial scaling indexing for multi-resolution outputs.""" for original_idx, p in zip(coords, patch): diff --git a/monai/utils/module.py b/monai/utils/module.py index b72e3ff139..fcd5e04145 100644 --- 
a/monai/utils/module.py +++ b/monai/utils/module.py @@ -12,6 +12,7 @@ from __future__ import annotations import enum +import functools import os import pdb import re @@ -508,6 +509,7 @@ def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): return default +@functools.lru_cache(None) def get_torch_version_tuple(): """ Returns: @@ -562,6 +564,7 @@ def _try_cast(val: str) -> int | str: return True +@functools.lru_cache(None) def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool: """ Compute whether the current pytorch version is after or equal to the specified version. From 2b0c25bdde17b34027d2fd7227c66aea2b310dd7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Apr 2023 11:30:27 -0400 Subject: [PATCH 12/30] adds test cases Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 17 ++++---- tests/test_sliding_window_inference.py | 55 +++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 8da2bdc237..a2d8c4b40e 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -27,6 +27,7 @@ convert_data_type, convert_to_dst_type, ensure_tuple, + ensure_tuple_rep, fall_back_tuple, look_up_option, optional_import, @@ -44,7 +45,7 @@ def sliding_window_inference( roi_size: Sequence[int] | int, sw_batch_size: int, predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], - overlap: float = 0.25, + overlap: Sequence[float] | float = 0.25, mode: BlendMode | str = BlendMode.CONSTANT, sigma_scale: Sequence[float] | float = 0.125, padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, @@ -91,7 +92,7 @@ def sliding_window_inference( to ensure the scaled output ROI sizes are still integers. If the `predictor`'s input and output spatial sizes are different, we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. - overlap: Amount of overlap between scans. + overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``. mode: {``"constant"``, ``"gaussian"``} How to blend output of overlapping windows. Defaults to ``"constant"``. @@ -137,8 +138,10 @@ def sliding_window_inference( raise ValueError(f"buffer_steps must be >= 0, got {buffer_steps}.") if buffer_dim < 0: buffer_dim += num_spatial_dims - if overlap < 0 or overlap >= 1: - raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") + overlap = ensure_tuple_rep(overlap, num_spatial_dims) + for o in overlap: + if o < 0 or o >= 1: + raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") compute_dtype = inputs.dtype # determine image spatial size and batch size @@ -330,7 +333,7 @@ def _compute_coords(sw, coords, z_scale, out, patch): def _get_scan_interval( - image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float + image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: Sequence[float] ) -> tuple[int, ...]: """ Compute scan interval according to the image size, roi size and overlap. 
@@ -344,11 +347,11 @@ def _get_scan_interval( raise ValueError(f"len(roi_size) {len(roi_size)} different from spatial dims {num_spatial_dims}.") scan_interval = [] - for i in range(num_spatial_dims): + for i, o in zip(range(num_spatial_dims), overlap): if roi_size[i] == image_size[i]: scan_interval.append(int(roi_size[i])) else: - interval = int(roi_size[i] * (1 - overlap)) + interval = int(roi_size[i] * (1 - o)) scan_interval.append(interval if interval > 0 else 1) return tuple(scan_interval) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 36a62d653f..5e445a2789 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -21,7 +21,7 @@ from monai.data.utils import list_data_collate from monai.inferers import SlidingWindowInferer, sliding_window_inference from monai.utils import optional_import -from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda +from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick _, has_tqdm = optional_import("tqdm") @@ -45,8 +45,61 @@ [(5, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, "constant", torch.device("cpu:0")], # 3D small roi ] +_devices = [["cpu", "cuda:0"]] if torch.cuda.is_available() else [["cpu"]] +_windows = [ + [(2, 3, 10, 11), (7, 10), 0.8, 5], + [(2, 3, 10, 11), (15, 12), 0, 2], + [(2, 3, 10, 11), (10, 11), 0, 3], + [(2, 3, 511, 237), (96, 80), 0.4, 5], + [(2, 3, 512, 245), (96, 80), 0, 5], + [(2, 3, 512, 245), (512, 80), 0.125, 5], + [(2, 3, 10, 11, 12), (7, 8, 10), 0.2, 2], +] +if not test_is_quick(): + _windows += [ + [(2, 1, 125, 512, 200), (96, 97, 98), (0.4, 0.32, 0), 20], + [(2, 1, 10, 512, 200), (96, 97, 98), (0.4, 0.12, 0), 21], + ] + +BUFFER_CASES = [] +for x in _windows: + for s in (1, 3, 5): + for d in (-1, 0, 1): + BUFFER_CASES.extend([x, s, d, dev] for dev in itertools.product(*_devices * 3)) + class TestSlidingWindowInference(unittest.TestCase): + @parameterized.expand(BUFFER_CASES) + def test_buffers(self, size_params, buffer_steps, buffer_dim, device_params): + def mult_two(patch, *args, **kwargs): + return 2.0 * patch + + img_size, roi_size, overlap, sw_batch_size = size_params + img_device, device, sw_device = device_params + dtype = [torch.float, torch.double][roi_size[0] % 2] # test different input dtype + mode = ["constant", "gaussian"][img_size[1] % 2] + image = torch.randint(0, 255, size=img_size, dtype=dtype, device=img_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + sw = sliding_window_inference( + image, + roi_size, + sw_batch_size, + mult_two, + overlap, + mode=mode, + sw_device=sw_device, + device=device, + buffer_steps=buffer_steps, + buffer_dim=buffer_dim, + ) + if torch.cuda.is_available(): + mem_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0.0) / 1024**3 + self.assertGreater(0.8, mem_peak) # less than 1GB + max_diff = torch.max(torch.abs(image.to(sw) - 0.5 * sw)).item() + self.assertGreater(0.001, max_diff) + @parameterized.expand(TEST_CASES) def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device): n_total = np.prod(image_shape) From 3bb4a53684c12d4ad4ba6b50b259bd1d266d581f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Apr 2023 16:38:36 +0100 Subject: [PATCH 13/30] update type Signed-off-by: Wenqi Li --- tests/test_sliding_window_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sliding_window_inference.py 
b/tests/test_sliding_window_inference.py
index 5e445a2789..7c8c7041ca 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -61,7 +61,7 @@
         [(2, 1, 10, 512, 200), (96, 97, 98), (0.4, 0.12, 0), 21],
     ]
 
-BUFFER_CASES = []
+BUFFER_CASES: list = []
 for x in _windows:
     for s in (1, 3, 5):
         for d in (-1, 0, 1):

From f56888d5812f4c0f85bb3cf3a7dbf4de81aabe51 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Sat, 1 Apr 2023 11:50:35 -0400
Subject: [PATCH 14/30] more tests

Signed-off-by: Wenqi Li

---
 tests/test_sliding_window_inference.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index 7c8c7041ca..c1a8368dcb 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -59,11 +59,12 @@
         _windows += [
             [(2, 1, 125, 512, 200), (96, 97, 98), (0.4, 0.32, 0), 20],
             [(2, 1, 10, 512, 200), (96, 97, 98), (0.4, 0.12, 0), 21],
+            [(2, 3, 100, 100, 200), (50, 50, 100), 0, 8],
         ]
 
 BUFFER_CASES: list = []
 for x in _windows:
-    for s in (1, 3, 5):
+    for s in (1, 3, 4):
         for d in (-1, 0, 1):
             BUFFER_CASES.extend([x, s, d, dev] for dev in itertools.product(*_devices * 3))
 

From bd42c16fd8b821f0076bcaac28a29fab7bce9af2 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Sat, 1 Apr 2023 12:32:02 -0400
Subject: [PATCH 15/30] update

Signed-off-by: Wenqi Li

---
 tests/test_sliding_window_inference.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index c1a8368dcb..117ad341c5 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -80,9 +80,6 @@ def mult_two(patch, *args, **kwargs):
         dtype = [torch.float, torch.double][roi_size[0] % 2]  # test different input dtype
         mode = ["constant", "gaussian"][img_size[1] % 2]
         image = torch.randint(0, 255, size=img_size, dtype=dtype, device=img_device)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
         sw = sliding_window_inference(
             image,
             roi_size,
@@ -95,9 +92,6 @@
             buffer_steps=buffer_steps,
             buffer_dim=buffer_dim,
         )
-        if torch.cuda.is_available():
-            mem_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0.0) / 1024**3
-            self.assertGreater(0.8, mem_peak)  # less than 1GB
         max_diff = torch.max(torch.abs(image.to(sw) - 0.5 * sw)).item()
         self.assertGreater(0.001, max_diff)
 

From 8cd28e9445e17db6c20e85c0dba3bf9ab5b20227 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Sat, 1 Apr 2023 13:45:51 -0400
Subject: [PATCH 16/30] nonblocking copy

Signed-off-by: Wenqi Li

---
 monai/inferers/utils.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index a2d8c4b40e..85a60b09ea 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -174,11 +174,19 @@ def sliding_window_inference(
     total_slices = num_win * batch_size  # total number of windows
     windows_range: Iterable
     if not buffered:
+        non_blocking = False
         windows_range = range(0, total_slices, sw_batch_size)
     else:
         slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
             slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
         )
+        non_blocking = buffered and overlap[buffer_dim] == 0 and torch.device(sw_device).type == "cuda"
+        _ss = -1
+        for x in b_slices[:n_per_batch]:
+            if x[1] < _ss:  # no overlapping slices
+                non_blocking = False
+                break
+            _ss = x[2]
 
    # 
Create window-level importance map valid_patch_size = get_valid_patch_size(image_size, roi_size) @@ -259,7 +267,10 @@ def sliding_window_inference( o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) - output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + if non_blocking: + output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + else: + output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) else: sw_device_buffer[ss] *= w_t sw_device_buffer[ss] = sw_device_buffer[ss].to(device) @@ -268,6 +279,9 @@ def sliding_window_inference( if buffered: b_s += 1 + if non_blocking: + torch.cuda.current_stream().synchronize() + # account for any overlapping sections for ss in range(len(output_image_list)): output_image_list[ss] /= count_map_list.pop(0) From 10d665f77b17172f0f96563c212171737bc35268 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Apr 2023 19:27:51 +0100 Subject: [PATCH 17/30] nonblocking copy Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 85a60b09ea..4e3f103352 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -183,7 +183,7 @@ def sliding_window_inference( non_blocking = buffered and overlap[buffer_dim] == 0 and torch.device(sw_device).type == "cuda" _ss = -1 for x in b_slices[:n_per_batch]: - if x[1] < _ss: # no overlapping slices + if x[1] < _ss: # detect overlapping slices non_blocking = False break _ss = x[2] @@ -255,7 +255,8 @@ def sliding_window_inference( output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts - output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device)) + new_tensor: Callable = torch.empty if non_blocking else torch.zeros + output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) w_t_ = w_t.to(device) for __s in slices: From f5211dff42d36dbd042415a2784e76d31ea950ee Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Apr 2023 19:37:20 +0100 Subject: [PATCH 18/30] docs Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 4e3f103352..8c31f4e329 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -119,7 +119,8 @@ def sliding_window_inference( If not given, and ``mode`` is not `constant`, this map will be computed on the fly. process_fn: process inference output and adjust the importance map per window buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. - default is None, no buffering. + default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size, + non_blocking copy may be automatically enabled for efficient processing. buffer_dim: the spatial dimension along which the buffer are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. args: optional args to be passed to ``predictor``. 
@@ -255,7 +256,7 @@ def sliding_window_inference( output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) # allocate memory to store the full output and the count for overlapping parts - new_tensor: Callable = torch.empty if non_blocking else torch.zeros + new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device)) count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) w_t_ = w_t.to(device) From e85efe93899225e066c721c1f58c5d0c96b4e8b9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 2 Apr 2023 15:20:41 +0100 Subject: [PATCH 19/30] simplify slice1 Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 8c31f4e329..83df59b193 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -199,6 +199,8 @@ def sliding_window_inference( importance_map_ = compute_importance_map( valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) + if len(importance_map_.shape) == (num_spatial_dims - 2): + importance_map_ = importance_map_[None, None] except Exception as e: raise RuntimeError( f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" @@ -215,7 +217,10 @@ def sliding_window_inference( [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] - win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + if len(unravel_slice) > 1: + win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + else: + win_data = inputs[unravel_slice[0]].to(sw_device) seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. 
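Taken together, the patches up to this point give ``sliding_window_inference`` two new arguments: ``buffer_steps``, the number of window rows to accumulate on ``sw_device`` before each write to ``device``, and ``buffer_dim``, the spatial axis along which those rows are formed. The following is a minimal usage sketch assuming this series is applied; the toy predictor, tensor shapes, and parameter values are illustrative assumptions only, not part of the patches:

import torch

from monai.inferers import sliding_window_inference


def toy_predictor(patch: torch.Tensor) -> torch.Tensor:
    # stand-in for a segmentation network; the output patch matches the input size
    return 2.0 * patch


image = torch.rand(1, 1, 64, 64, 48)  # batch, channel, then three spatial dims
output = sliding_window_inference(
    image,
    roi_size=(32, 32, 16),
    sw_batch_size=4,
    predictor=toy_predictor,
    overlap=0.25,
    mode="gaussian",
    buffer_steps=2,  # flush the sw_device buffer to device every two rows of windows
    buffer_dim=-1,  # buffer along the last spatial dimension (the default)
)
assert output.shape == image.shape

When ``overlap`` is zero along ``buffer_dim``, consecutive buffers tile the volume without intersecting; that is the condition under which the non-blocking device copy introduced in PATCH 16 can stay enabled.
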
From e879a9c30c123bb0cfbd29838f5e1665c9d792fa Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 2 Apr 2023 16:09:15 +0100 Subject: [PATCH 20/30] update Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 823c540274..4cfc7a980b 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -181,7 +181,7 @@ def sliding_window_inference( slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( slices, batch_size, sw_batch_size, buffer_dim, buffer_steps ) - non_blocking = buffered and torch.device(device).type == "cuda" + non_blocking = buffered _ss = -1 for x in b_slices[:n_per_batch]: if x[1] < _ss: # detect overlapping slices From d787eef797389313fe9390dcb54859606d1d17d6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 2 Apr 2023 16:14:53 +0100 Subject: [PATCH 21/30] simplify non-blocking flag when buffered=True Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 4cfc7a980b..0c3b1a2b85 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -181,8 +181,7 @@ def sliding_window_inference( slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( slices, batch_size, sw_batch_size, buffer_dim, buffer_steps ) - non_blocking = buffered - _ss = -1 + non_blocking, _ss = True, -1 for x in b_slices[:n_per_batch]: if x[1] < _ss: # detect overlapping slices non_blocking = False From 545ad5ecc384ca34a1836149feb64dfc3bc639ff Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 2 Apr 2023 16:24:58 +0100 Subject: [PATCH 22/30] fixes no cuda Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 0c3b1a2b85..2aa6ef8741 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -181,7 +181,7 @@ def sliding_window_inference( slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( slices, batch_size, sw_batch_size, buffer_dim, buffer_steps ) - non_blocking, _ss = True, -1 + non_blocking, _ss = torch.cuda.is_available(), -1 for x in b_slices[:n_per_batch]: if x[1] < _ss: # detect overlapping slices non_blocking = False From b55b892433adbfe15affbfd18c499baae28ffd35 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 3 Apr 2023 03:48:23 -0400 Subject: [PATCH 23/30] update docs Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 2aa6ef8741..46ba814a5b 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -118,10 +118,11 @@ def sliding_window_inference( roi_weight_map: pre-computed (non-negative) weight map for each ROI. If not given, and ``mode`` is not `constant`, this map will be computed on the fly. process_fn: process inference output and adjust the importance map per window - buffer_steps: the number of sliding window iterations before writing the outputs to ``device``. + buffer_steps: the number of sliding window iterations along the ``buffer_dim`` + to be buffered on ``sw_device`` before writing to ``device``. default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size, - non_blocking copy may be automatically enabled for efficient processing. 
- buffer_dim: the spatial dimension along which the buffer are created. + (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. + buffer_dim: the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -135,8 +136,6 @@ def sliding_window_inference( if buffered: if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") - if buffer_steps <= 0: # type: ignore - raise ValueError(f"buffer_steps must be >= 0, got {buffer_steps}.") if buffer_dim < 0: buffer_dim += num_spatial_dims overlap = ensure_tuple_rep(overlap, num_spatial_dims) From 1ac4f78b26cc172e9e24b0000e8667eb0bcce670 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 3 Apr 2023 11:11:09 +0100 Subject: [PATCH 24/30] prepare weight map dims Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 46ba814a5b..90fc3cd1dc 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -197,8 +197,8 @@ def sliding_window_inference( importance_map_ = compute_importance_map( valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) - if len(importance_map_.shape) == (num_spatial_dims - 2): - importance_map_ = importance_map_[None, None] + if len(importance_map_.shape) == num_spatial_dims: + importance_map_ = importance_map_[None, None] # adds batch, channel dimensions except Exception as e: raise RuntimeError( f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" From 2c8248d3fd6cf72c4d0f9511ae523aa6fa2a30a8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 3 Apr 2023 11:45:39 +0100 Subject: [PATCH 25/30] compatible sliding_window_hovernet_inference Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 90fc3cd1dc..deb754b03a 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -197,7 +197,7 @@ def sliding_window_inference( importance_map_ = compute_importance_map( valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype ) - if len(importance_map_.shape) == num_spatial_dims: + if len(importance_map_.shape) == num_spatial_dims and not process_fn: importance_map_ = importance_map_[None, None] # adds batch, channel dimensions except Exception as e: raise RuntimeError( From ab89dd9e87a6012a25818303d78c9dc38d9e002b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 5 Apr 2023 11:45:48 +0100 Subject: [PATCH 26/30] simplify Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index deb754b03a..59fb479904 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -254,7 +254,7 @@ def sliding_window_inference( z_scale = None if not buffered and seg_shape != roi_size: z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] - w_t = torch.nn.functional.interpolate(w_t, seg_shape, mode=_nearest_mode) + w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode) if len(output_image_list) <= ss: 
output_shape = [batch_size, seg_chns] output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) From e606d5589739cdc007eb0bc0462a0217cfc2a86f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 5 Apr 2023 21:19:33 +0100 Subject: [PATCH 27/30] remove existing Signed-off-by: Wenqi Li --- .github/workflows/integration.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 456fa10c41..14e269cc74 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -2,6 +2,9 @@ name: integration on: + push: + branches: + - temp-tests repository_dispatch: type: [integration-test-command] @@ -59,6 +62,7 @@ jobs: # test latest template cd ../ + rm -rf research-contributions git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git ls research-contributions/ cp -r research-contributions/auto3dseg/algorithm_templates ../MONAI/ From d77bb90763416e17903050e7f8fba7548e864d05 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 5 Apr 2023 21:56:04 +0100 Subject: [PATCH 28/30] update Signed-off-by: Wenqi Li --- .github/workflows/integration.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 14e269cc74..469e58b9cb 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -54,6 +54,7 @@ jobs: python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' # test auto3dseg + echo "test tag algo" BUILD_MONAI=0 ./runtests.sh --build python -m tests.test_auto3dseg_ensemble python -m tests.test_auto3dseg_hpo @@ -61,6 +62,7 @@ jobs: python -m tests.test_integration_gpu_customization # test latest template + echo "test latest algo" cd ../ rm -rf research-contributions git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git @@ -76,6 +78,9 @@ jobs: python -m tests.test_integration_gpu_customization # the other tests + echo "the other tests" + pwd + ls -ll BUILD_MONAI=1 ./runtests.sh --build --net BUILD_MONAI=1 ./runtests.sh --build --unittests if pgrep python; then pkill python; fi From b4a8a9a8f5c6a27c37bd707c17d871f0bd26ee72 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 5 Apr 2023 22:34:44 +0100 Subject: [PATCH 29/30] update Signed-off-by: Wenqi Li --- .github/workflows/integration.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 469e58b9cb..41d5fbda49 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -37,6 +37,10 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + pip uninstall -y monai + pip uninstall -y monai + pip uninstall -y monai-weekly + pip uninstall -y monai-weekly python -m pip install --upgrade torch torchvision torchaudio python -m pip install -r requirements-dev.txt rm -rf /github/home/.cache/torch/hub/mmars/ @@ -65,10 +69,13 @@ jobs: echo "test latest algo" cd ../ rm -rf research-contributions + rm -rf algorithm_templates git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git ls research-contributions/ - cp -r research-contributions/auto3dseg/algorithm_templates ../MONAI/ - cd research-contributions && git log -1 && cd .. 
+ cp -r research-contributions/auto3dseg/algorithm_templates MONAI/ + cd research-contributions && git log -1 && cd ../MONAI + pwd + ls -ll export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 export MONAI_TESTING_ALGO_TEMPLATE=algorithm_templates From a0cf13499678b112880e67a61dfc45d81a062bdb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 5 Apr 2023 22:37:38 +0100 Subject: [PATCH 30/30] update Signed-off-by: Wenqi Li --- .github/workflows/integration.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 41d5fbda49..8ae028bbd2 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -2,9 +2,6 @@ name: integration on: - push: - branches: - - temp-tests repository_dispatch: type: [integration-test-command]
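For reference, the per-axis ``overlap`` support added in PATCH 12 reduces to a simple stepping rule: along each spatial dimension the scan interval is the ROI size shrunk by that axis's overlap fraction and clamped to at least one voxel, and an axis fully covered by the ROI gets a single window. The standalone function below is only a sketch of the patched ``_get_scan_interval``, assuming ``overlap`` has already been broadcast to one value per spatial dimension (as ``ensure_tuple_rep`` does in the patch):

from collections.abc import Sequence


def scan_interval(image_size: Sequence[int], roi_size: Sequence[int], overlap: Sequence[float]) -> tuple[int, ...]:
    """Per-axis window stride for a dense sliding-window scan."""
    interval = []
    for img_d, roi_d, o in zip(image_size, roi_size, overlap):
        if roi_d == img_d:
            interval.append(roi_d)  # one window already covers this axis
        else:
            interval.append(max(int(roi_d * (1 - o)), 1))  # stride shrinks as overlap grows
    return tuple(interval)


print(scan_interval((96, 96, 48), (64, 64, 48), (0.25, 0.5, 0.0)))  # prints (48, 32, 48)

Mixed per-axis settings such as ``(0.4, 0.32, 0)`` in the test cases above therefore yield a different stride on every axis.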