39 sliding window inference #52
Closed
Changes from all commits
New file (44 lines added), module `monai.data.transforms`:

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import monai

export = monai.utils.export("monai.data.transforms")


@export
class ImageEndPadder:
    """Pads each spatial dimension of the data by appending to its end, all on one side.

    Uses np.pad internally, so a padding mode must be provided; see
    numpy.lib.arraypad.pad for additional details.

    Args:
        out_size (list): minimum spatial size of the output; dimensions that are
            already at least this large are left unchanged.
        mode (string): padding mode, passed through to np.pad.
        dtype: output data type.
    """

    def __init__(self, out_size, mode, dtype=np.float32):
        assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple'
        self.out_size = out_size
        assert isinstance(mode, str), 'mode must be str'
        self.mode = mode
        self.dtype = dtype

    def _determine_data_pad_width(self, data_shape):
        # pad only at the end of each spatial dim, and only where it is smaller than out_size
        return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))]

    def __call__(self, img):
        # img is channel-first with a batch dim; the batch and channel dims are never padded
        data_pad_width = self._determine_data_pad_width(img.shape[2:])
        all_pad_width = [(0, 0), (0, 0)] + data_pad_width
        img = np.pad(img, all_pad_width, self.mode)
        return img
```
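A minimal usage sketch of the padder (my own illustration, mirroring the unit test further down; the shapes are chosen for the example only):

```python
import numpy as np

from monai.data.transforms import ImageEndPadder

# pad an 8x8x4 volume (batch of 1, 3 channels) up to at least 16x16x8
padder = ImageEndPadder(out_size=[16, 16, 8], mode='constant')
img = np.zeros((1, 3, 8, 8, 4), dtype=np.float32)
padded = padder(img)
print(padded.shape)  # (1, 3, 16, 16, 8) -- padding is appended at the end of each spatial dim
```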
New file (165 lines added), module `monai.utils.sliding_window_inference`:

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch

from monai.data.transforms import ImageEndPadder


def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device):
    """Use the SlidingWindow method to execute inference.

    Args:
        inputs: the image to process, channel-first with a batch dimension.
        roi_size (list, tuple): the window size used for SlidingWindow inference.
        sw_batch_size (int): the batch size for running window slices.
        predictor (callable): given a batch of window data, returns the corresponding batch of predictions.
        device: the device on which to execute model inference, CPU or GPU.

    Note:
        The input must be channel first; both 2D and 3D data are supported.
        The input data must have a batch dimension.
        Inference runs on one image at a time, as batches of window slices of that image.
    """
    num_spatial_dims = len(inputs.shape) - 2
    assert len(roi_size) == num_spatial_dims, 'roi_size {} does not match input dims.'.format(roi_size)

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    image_size = list(inputs.shape[2:])
    batch_size = inputs.shape[0]

    # TODO: Enable batch sizes > 1 in future.
    assert batch_size == 1, "input batch size must be 1"

    # in case the image size is smaller than the roi size
    is_oversized = False
    original_image_size = [image_size[i] for i in range(num_spatial_dims)]

    for i in range(num_spatial_dims):
        if int(roi_size[i]) > image_size[i]:
            is_oversized = True
            break

    if is_oversized:
        for i in range(num_spatial_dims):
            image_size[i] = max(image_size[i], roi_size[i])

        padder = ImageEndPadder(roi_size, 'constant')
        inputs = padder(inputs)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
    scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) for i in range(num_spatial_dims)]

    # Store all slices in a list.
    slices = []
    if num_spatial_dims == 3:
        for i in range(scan_num[0]):
            start_i = i * scan_interval[0]
            start_i -= max(start_i + roi_size[0] - image_size[0], 0)
            slice_i = slice(start_i, start_i + roi_size[0])

            for j in range(scan_num[1]):
                start_j = j * scan_interval[1]
                start_j -= max(start_j + roi_size[1] - image_size[1], 0)
                slice_j = slice(start_j, start_j + roi_size[1])

                for k in range(scan_num[2]):
                    start_k = k * scan_interval[2]
                    start_k -= max(start_k + roi_size[2] - image_size[2], 0)
                    slice_k = slice(start_k, start_k + roi_size[2])
                    slices.append((slice_i, slice_j, slice_k))
    else:
        for i in range(scan_num[0]):
            start_i = i * scan_interval[0]
            start_i -= max(start_i + roi_size[0] - image_size[0], 0)
            slice_i = slice(start_i, start_i + roi_size[0])

            for j in range(scan_num[1]):
                start_j = j * scan_interval[1]
                start_j -= max(start_j + roi_size[1] - image_size[1], 0)
                slice_j = slice(start_j, start_j + roi_size[1])
                slices.append((slice_i, slice_j))

    # group the window slices into batches of sw_batch_size
    buffered_requests = []
    for slice_index in range(0, len(slices), sw_batch_size):
        slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices)))
        input_slices = []
        for curr_index in slice_index_range:
            if num_spatial_dims == 3:
                slice_i, slice_j, slice_k = slices[curr_index]
                input_slices.append(inputs[0, :, slice_i, slice_j, slice_k])
            else:
                slice_i, slice_j = slices[curr_index]
                input_slices.append(inputs[0, :, slice_i, slice_j])
        buffered_requests.append(np.stack(input_slices))

    # Perform predictions
    output_rois = list()
    for data in buffered_requests:
        output_rois.append(predictor(data))

    output_classes = output_rois[0].shape[1]
    output_shape = [batch_size, output_classes] + list(image_size)

    # allocate memory to store the full output and the count for overlapping parts
    output_dict = torch.zeros(output_shape, dtype=torch.float32, device=device)
    count_dict = torch.zeros(output_shape, dtype=torch.float32, device=device)

    window_index = 0
    for slice_index in range(0, len(slices), sw_batch_size):
        slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices)))
        output_roi = output_rois[window_index]
        window_index += 1

        # store the result in the proper location of the full output
        for curr_index in slice_index_range:
            if num_spatial_dims == 3:
                slice_i, slice_j, slice_k = slices[curr_index]
                output_dict[0, :, slice_i, slice_j, slice_k] += \
                    output_roi[curr_index - slice_index, :]
                count_dict[0, :, slice_i, slice_j, slice_k] += 1
            else:
                slice_i, slice_j = slices[curr_index]
                output_dict[0, :, slice_i, slice_j] += \
                    output_roi[curr_index - slice_index, :]
                count_dict[0, :, slice_i, slice_j] += 1

    # account for any overlapping sections
    output_dict /= count_dict

    # in case the image size was smaller than the roi size, crop back to the original size
    if is_oversized:
        if num_spatial_dims == 3:
            new_output_dict = output_dict[:, :, :original_image_size[0],
                                          :original_image_size[1], :original_image_size[2]]
        else:
            new_output_dict = output_dict[:, :, :original_image_size[0], :original_image_size[1]]

        output_dict = new_output_dict

    return output_dict


def _get_scan_interval(image_size, roi_size, num_spatial_dims):
    assert len(image_size) == num_spatial_dims, 'image coord different from spatial dims.'
    assert len(roi_size) == num_spatial_dims, 'roi coord different from spatial dims.'

    scan_interval = [1 for _ in range(num_spatial_dims)]
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval[i] = int(roi_size[i])
        else:
            # the interval is roi_size - 16 when roi_size >= 64, and 0.75 * roi_size otherwise
            scan_interval[i] = int(max(roi_size[i] - 16, roi_size[i] * 0.75))
    return scan_interval
```
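For context, a minimal usage sketch (my own illustration, not part of the diff; `compute` stands in for a real network and simply adds 1 to each window batch):

```python
import numpy as np
import torch

from monai.utils.sliding_window_inference import sliding_window_inference

device = torch.device("cpu")  # any torch device works; the unit test below uses cuda:0

def compute(data):
    # toy predictor: turn the numpy window batch into a tensor and add 1
    return torch.from_numpy(data).to(device) + 1

inputs = np.ones((1, 3, 16, 16, 8), dtype=np.float32)  # (batch, channel, D, H, W)
result = sliding_window_inference(inputs, [4, 4, 4], 4, compute, device)
print(result.shape)  # torch.Size([1, 3, 16, 16, 8]); every voxel equals 2 after averaging overlaps
```

The heuristic in `_get_scan_interval` makes adjacent windows overlap by 16 voxels for large ROIs and by 25% for small ones: for example, a 96-voxel ROI steps by max(96 - 16, 72) = 80, while a 32-voxel ROI steps by max(32 - 16, 24) = 24.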
New test file (36 lines added) exercising `ImageEndPadder`:

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.data.transforms import ImageEndPadder

TEST_CASE_1 = [
    {
        'out_size': [16, 16, 8],
        'mode': 'constant'
    },
    np.zeros((1, 3, 8, 8, 4)),
    np.zeros((1, 3, 16, 16, 8)),
]


class TestImageEndPadder(unittest.TestCase):

    @parameterized.expand([TEST_CASE_1])
    def test_image_end_pad_shape(self, input_param, input_data, expected_val):
        padder = ImageEndPadder(**input_param)
        result = padder(input_data)
        # shapes are plain tuples, so a direct equality check is clearer than assertAlmostEqual
        self.assertEqual(result.shape, expected_val.shape)


if __name__ == '__main__':
    unittest.main()
```
New test file (37 lines added) exercising `sliding_window_inference`:

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import numpy as np

from monai.utils.sliding_window_inference import sliding_window_inference


class TestSlidingWindowInference(unittest.TestCase):

    def test_sliding_window_default(self):
        inputs = np.ones((1, 3, 16, 16, 8))
        roi_size = [4, 4, 4]
        sw_batch_size = 4
        device = torch.device("cuda:0")

        def compute(data):
            # toy predictor: move each window batch to the device and add 1
            data = torch.from_numpy(data)
            return data.to(device) + 1

        result = sliding_window_inference(inputs, roi_size, sw_batch_size, compute, device)
        expected_val = torch.ones((1, 3, 16, 16, 8), dtype=torch.float32, device=device) + 1
        # shapes are plain tuples, so a direct equality check is clearer than assertAlmostEqual
        self.assertEqual(result.shape, expected_val.shape)


if __name__ == '__main__':
    unittest.main()
```
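The test above hard-codes `cuda:0`, so it assumes a CUDA device is available. One optional guard (my suggestion, not part of this PR) would be to skip the case on CPU-only machines:

```python
import unittest

import torch


@unittest.skipUnless(torch.cuda.is_available(), "test requires a CUDA device")
class TestSlidingWindowInference(unittest.TestCase):
    ...  # test body as above
```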