diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index 7fa7b9e9ff..cc4cea8c1e 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -248,6 +248,22 @@ FastMRIReader
 ~~~~~~~~~~~~~
 .. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj
 
+`Vista3d`
+---------
+.. automodule:: monai.apps.vista3d.inferer
+.. autofunction:: point_based_window_inferer
+
+.. automodule:: monai.apps.vista3d.transforms
+.. autoclass:: VistaPreTransformd
+    :members:
+.. autoclass:: VistaPostTransformd
+    :members:
+.. autoclass:: Relabeld
+    :members:
+
+.. automodule:: monai.apps.vista3d.sampler
+.. autofunction:: sample_prompt_pairs
+
 `Auto3DSeg`
 -----------
 .. automodule:: monai.apps.auto3dseg
diff --git a/monai/apps/generation/maisi/utils/__init__.py b/monai/apps/vista3d/__init__.py
similarity index 100%
rename from monai/apps/generation/maisi/utils/__init__.py
rename to monai/apps/vista3d/__init__.py
diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py
new file mode 100644
index 0000000000..709f81f624
--- /dev/null
+++ b/monai/apps/vista3d/inferer.py
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+__all__ = ["point_based_window_inferer"]
+
+
+def point_based_window_inferer(
+    inputs: torch.Tensor | MetaTensor,
+    roi_size: Sequence[int],
+    predictor: torch.nn.Module,
+    point_coords: torch.Tensor,
+    point_labels: torch.Tensor,
+    class_vector: torch.Tensor | None = None,
+    prompt_class: torch.Tensor | None = None,
+    prev_mask: torch.Tensor | MetaTensor | None = None,
+    point_start: int = 0,
+    center_only: bool = True,
+    margin: int = 5,
+    **kwargs: Any,
+) -> torch.Tensor:
+    """
+    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
+    The inferer crops patches centered on the given points, runs inference on each patch, stitches the outputs
+    back together by averaging the overlapping predictions, and returns the resulting segmentation mask.
+
+    Args:
+        inputs: [1CHWD], input image to be processed.
+        roi_size: the spatial window size for inference. If a component of `roi_size` is None or non-positive,
+            the corresponding dimension of the input image will be used instead. For example, `roi_size=(32, -1)`
+            will be adapted to `(32, 64)` if the second spatial dimension of the image is `64`.
+        predictor: the model. For VISTA3D, the output is [B, 1, H, W, D], which needs to be transposed to
+            [1, B, H, W, D]; add `transpose=True` to kwargs when using VISTA3D.
+        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each with N points.
+        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
+            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
+        class_vector: [B]. Used for class-head automatic segmentation. Can be None.
+        prompt_class: [B]. The same as class_vector; it represents the point class and informs the point head
+            whether the class is supported or zero-shot. Not used for automatic segmentation. If None, the point
+            head defaults to supported-class segmentation.
+        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+        point_start: only use points starting from this index. All points before this index are used to generate
+            prev_mask. This avoids re-computing points from previous iterations when prev_mask is given.
+        center_only: for each point, only crop the patch centered at this point. If False, crop 3 patches for each point.
+        margin: if center_only is False, this value is the distance between the point and the patch boundary.
+    Returns:
+        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+    Note: this function only supports SINGLE OBJECT INFERENCE with B=1.
+    """
+    if not point_coords.shape[0] == 1:
+        raise ValueError("Only supports single object point click.")
+    if not len(inputs.shape) == 5:
+        raise ValueError("Input image should be 5D.")
+    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
+    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
+    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
+    stitched_output = None
+    for p in point_coords[0][point_start:]:
+        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
+        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
+        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
+        for i in range(len(lx_)):
+            for j in range(len(ly_)):
+                for k in range(len(lz_)):
+                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
+                    unravel_slice = [
+                        slice(None),
+                        slice(None),
+                        slice(int(lx), int(rx)),
+                        slice(int(ly), int(ry)),
+                        slice(int(lz), int(rz)),
+                    ]
+                    batch_image = image[unravel_slice]
+                    output = predictor(
+                        batch_image,
+                        point_coords=point_coords,
+                        point_labels=point_labels,
+                        class_vector=class_vector,
+                        prompt_class=prompt_class,
+                        patch_coords=unravel_slice,
+                        prev_mask=prev_mask,
+                        **kwargs,
+                    )
+                    if stitched_output is None:
+                        stitched_output = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                        stitched_mask = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                    stitched_output[unravel_slice] += output.to("cpu")
+                    stitched_mask[unravel_slice] = 1
+    # voxels never covered by any patch have stitched_mask == 0 and become NaN after this division
+    stitched_output = stitched_output / stitched_mask
+    # revert padding
+    stitched_output = stitched_output[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    stitched_mask = stitched_mask[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    if prev_mask is not None:
+        prev_mask = prev_mask[
+            :,
+            :,
+            pad[4] : image.shape[-3] - pad[5],
+            pad[2] : image.shape[-2] - pad[3],
+            pad[0] : image.shape[-1] -
pad[1], + ] + prev_mask = prev_mask.to("cpu") # type: ignore + # for un-calculated place, use previous mask + stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + if isinstance(inputs, torch.Tensor): + inputs = MetaTensor(inputs) + if not hasattr(stitched_output, "meta"): + stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) + return stitched_output + + +def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]: + """Helper function to get the window index.""" + if p - roi // 2 < 0: + left, right = 0, roi + elif p + roi // 2 > s: + left, right = s - roi, s + else: + left, right = int(p) - roi // 2, int(p) + roi // 2 + return left, right + + +def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]: + """Get the window index.""" + left, right = _get_window_idx_c(p, roi, s) + if center_only: + return [left], [right] + left_most = max(0, p - roi + margin) + right_most = min(s, p + roi - margin) + left_list = [left_most, right_most - roi, left] + right_list = [left_most + roi, right_most, right] + return left_list, right_list + + +def _pad_previous_mask( + inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0 +) -> tuple[torch.Tensor | MetaTensor, list[int]]: + """Helper function to pad inputs.""" + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore + return inputs, pad_size diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py new file mode 100644 index 0000000000..b7aeb89a2e --- /dev/null +++ b/monai/apps/vista3d/sampler.py @@ -0,0 +1,172 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
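Before the sampler implementation below, a minimal usage sketch of the `point_based_window_inferer` added above may help reviewers. This is not part of the diff; it mirrors the new `tests/test_point_based_window_inferer.py` cases, and the network size, image, and click values are illustrative only:

import torch

from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.networks.nets.vista3d import vista3d132

device = "cuda" if torch.cuda.is_available() else "cpu"
model = vista3d132(encoder_embed_dim=48, in_channels=1).to(device).eval()  # small config, as in the tests

image = torch.randn(1, 1, 64, 64, 64, device=device)  # [1, C, H, W, D]
point_coords = torch.tensor([[[32, 32, 32]]], device=device)  # [B=1, N=1, 3]; single-object only
point_labels = torch.tensor([[1]], device=device)  # 1 = positive click

with torch.no_grad():
    logits = point_based_window_inferer(
        inputs=image,
        roi_size=[32, 32, 32],
        predictor=model,
        point_coords=point_coords,
        point_labels=point_labels,
        transpose=True,  # VISTA3D outputs [B, 1, H, W, D]; see the predictor note in the docstring
    )
mask = logits.sigmoid() > 0.5  # stitched logits are pre-sigmoid, per the Returns section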
+
+from __future__ import annotations
+
+import copy
+import random
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import numpy as np
+import torch
+from torch import Tensor
+
+__all__ = ["sample_prompt_pairs"]
+
+ENABLE_SPECIAL = True
+SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
+MERGE_LIST = {
+    1: [25, 26],  # hepatic tumor and vessel merge into liver
+    4: [24],  # pancreatic tumor merge into pancreas
+    132: [57],  # overlap with trachea merge into airway
+}
+
+
+def _get_point_label(id: int) -> tuple[int, int]:
+    if id in SPECIAL_INDEX and ENABLE_SPECIAL:
+        return 2, 3
+    else:
+        return 0, 1
+
+
+def sample_prompt_pairs(
+    labels: Tensor,
+    label_set: Sequence[int],
+    max_prompt: int | None = None,
+    max_foreprompt: int | None = None,
+    max_backprompt: int = 1,
+    max_point: int = 20,
+    include_background: bool = False,
+    drop_label_prob: float = 0.2,
+    drop_point_prob: float = 0.2,
+    point_sampler: Callable | None = None,
+    **point_sampler_kwargs: Any,
+) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
+    """
+    Sample training pairs for VISTA3D training.
+
+    Args:
+        labels: [1, 1, H, W, D], ground truth labels.
+        label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
+            it will be added into automatic branch training. It is recommended to remove 0 from label_set
+            for multi-partially-labeled-dataset training and to add 0 when finetuning on a specific dataset,
+            because a region labeled 0 in one partially labeled dataset may contain foreground from
+            another dataset.
+        max_prompt: int, max number of total prompts, including foreground and background.
+        max_foreprompt: int, max number of prompts from foreground.
+        max_backprompt: int, max number of prompts from background.
+        max_point: maximum number of points for each object.
+        include_background: whether to include 0 in the training prompts. If included, background 0 is treated
+            the same as foreground. This should always be False for multi-partial-dataset training; if needed,
+            it can be True when finetuning on a specific dataset.
+        drop_label_prob: probability to drop label prompt.
+        drop_point_prob: probability to drop point prompt.
+        point_sampler: sampler to augment masks with supervoxels.
+        point_sampler_kwargs: arguments for point_sampler.
+
+    Returns:
+        label_prompt: [B, 1]. The classes used for training automatic segmentation.
+        point: [B, N, 3]. The corresponding points for each class.
+            Note that background label prompts require matching points as well ([0, 0, 0] is used).
+        point_label: [B, N]. The corresponding point labels for each point (negative or positive).
+            -1 is used to pad the background label prompt and will be ignored.
+        prompt_class: [B, 1], exactly the same as label_prompt, used for label indexing in the training loss.
+            label_prompt can be None, in which case prompt_class is used to identify the point classes.
+ """ + # class label number + if not labels.shape[0] == 1: + raise ValueError("only support batch size 1") + labels = labels[0, 0] + device = labels.device + unique_labels = labels.unique().cpu().numpy().tolist() + if include_background: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) + else: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0}) + background_labels = list(set(label_set) - set(unique_labels)) + # during training, balance background and foreground prompts + if max_backprompt is not None: + if len(background_labels) > max_backprompt: + random.shuffle(background_labels) + background_labels = background_labels[:max_backprompt] + + if max_foreprompt is not None: + if len(unique_labels) > max_foreprompt: + random.shuffle(unique_labels) + unique_labels = unique_labels[:max_foreprompt] + + if max_prompt is not None: + if len(unique_labels) + len(background_labels) > max_prompt: + if len(unique_labels) > max_prompt: + unique_labels = random.sample(unique_labels, max_prompt) + background_labels = [] + else: + background_labels = random.sample(background_labels, max_prompt - len(unique_labels)) + _point = [] + _point_label = [] + # if use regular sampling + if point_sampler is None: + num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) + num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) + for id in unique_labels: + neg_id, pos_id = _get_point_label(id) + plabels = labels == int(id) + nlabels = ~plabels + plabelpoints = torch.nonzero(plabels) + nlabelpoints = torch.nonzero(nlabels) + # final sampled positive points + num_pa = min(len(plabelpoints), num_p) + # final sampled negative points + num_na = min(len(nlabelpoints), num_n) + _point.append( + torch.stack( + random.choices(plabelpoints, k=num_pa) + + random.choices(nlabelpoints, k=num_na) + + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na) + ) + ) + _point_label.append( + torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to( + device + ) + ) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + else: + _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 + _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point + if len(unique_labels) == 0 and len(background_labels) == 0: + # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must + # be skipped. Handle this in trainer. + label_prompt, point, point_label, prompt_class = None, None, None, None + else: + label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # If label prompt is dropped, there is no need to pad with points with label -1. 
+            pad = len(background_labels)
+            point = point[: len(point) - pad]  # type: ignore
+            point_label = point_label[: len(point_label) - pad]
+            prompt_class = prompt_class[: len(prompt_class) - pad]
+    else:
+        if random.uniform(0, 1) < drop_point_prob:
+            point = None
+            point_label = None
+    return label_prompt, point, point_label, prompt_class
diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py
new file mode 100644
index 0000000000..3e8145cd80
--- /dev/null
+++ b/monai/apps/vista3d/transforms.py
@@ -0,0 +1,224 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from typing import Sequence
+
+import numpy as np
+import torch
+
+from monai.config import DtypeLike, KeysCollection
+from monai.transforms import MapLabelValue
+from monai.transforms.transform import MapTransform
+from monai.transforms.utils import keep_components_with_positive_points
+from monai.utils import look_up_option
+
+__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"]
+
+
+def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
+    """get the label name to index mapping"""
+    name_to_index_mapping = {}
+    if labels_dict is not None:
+        name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}
+    return name_to_index_mapping
+
+
+def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
+    """convert the label name to index"""
+    if label_prompt is not None and isinstance(label_prompt, list):
+        converted_label_prompt = []
+        # for new classes, add to the mapping
+        for l in label_prompt:
+            if isinstance(l, str) and not l.isdigit():
+                if l.lower() not in name_to_index_mapping:
+                    name_to_index_mapping[l.lower()] = len(name_to_index_mapping)
+        for l in label_prompt:
+            if isinstance(l, (int, str)):
+                converted_label_prompt.append(
+                    name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)
+                )
+            else:
+                converted_label_prompt.append(l)
+        return converted_label_prompt
+    return label_prompt
+
+
+class VistaPreTransformd(MapTransform):
+    def __init__(
+        self,
+        keys: KeysCollection,
+        allow_missing_keys: bool = False,
+        special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
+        labels_dict: dict | None = None,
+        subclass: dict | None = None,
+    ) -> None:
+        """
+        Pre-transform for Vista3d.
+
+        It performs two functionalities:
+
+        1. If the label prompt shows that the points belong to a special class (defined by special_index, e.g.
+           tumors, vessels), convert the point labels from 0 (negative)/1 (positive) to the special values
+           2 (negative)/3 (positive).
+
+        2. If the label prompt is within the keys of subclass, convert the label prompt to its subclasses defined
+           by subclass[key]. e.g. the "lung" label is converted to ["left lung", "right lung"].
+
+        The `label_prompt` is a list of length B of int values, and `point_labels` is a list of length B,
+        where each element is a list of N point labels.
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            special_index: the indexes that define the special classes.
+            labels_dict: an optional dictionary that maps label indexes to label names; it is used to convert
+                label-name prompts into indexes.
+            subclass: a dictionary that maps a label prompt to its subclasses.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.special_index = special_index
+        self.subclass = subclass
+        self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)
+
+    def __call__(self, data):
+        label_prompt = data.get("label_prompt", None)
+        point_labels = data.get("point_labels", None)
+        # convert the label name to index if needed
+        label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
+        try:
+            # The evaluator will check the prompt; invalid prompts are skipped here and captured by the evaluator.
+            if self.subclass is not None and label_prompt is not None:
+                _label_prompt = []
+                subclass_keys = list(map(int, self.subclass.keys()))
+                for i in range(len(label_prompt)):
+                    if label_prompt[i] in subclass_keys:
+                        _label_prompt.extend(self.subclass[str(label_prompt[i])])
+                    else:
+                        _label_prompt.append(label_prompt[i])
+                data["label_prompt"] = _label_prompt
+            if label_prompt is not None and point_labels is not None:
+                if label_prompt[0] in self.special_index:
+                    point_labels = np.array(point_labels)
+                    point_labels[point_labels == 0] = 2
+                    point_labels[point_labels == 1] = 3
+                    point_labels = point_labels.tolist()
+                data["point_labels"] = point_labels
+        except Exception:
+            # There are specific requirements for `label_prompt` and `point_labels`.
+            # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.
+            # Those formatting errors should be captured later.
+            warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.")
+
+        return data
+
+
+class VistaPostTransformd(MapTransform):
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
+        """
+        Post-transform for Vista3d. It converts the model output logits into final segmentation masks.
+        If `label_prompt` is None, the output will be thresholded to sequential indexes [0, 1, 2, ...];
+        otherwise the indexes will be [0, label_prompt[0], label_prompt[1], ...].
+        If `label_prompt` is None while `points` are provided, the transform will postprocess the output and
+        remove regions that do not contain positive points.
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            allow_missing_keys: don't raise exception if key is missing.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+
+    def __call__(self, data):
+        """data["label_prompt"] should not contain 0"""
+        for keys in self.keys:
+            if keys in data:
+                pred = data[keys]
+                object_num = pred.shape[0]
+                device = pred.device
+                if data.get("label_prompt", None) is None and data.get("points", None) is not None:
+                    pred = keep_components_with_positive_points(
+                        pred.unsqueeze(0),
+                        point_coords=data.get("points").to(device),
+                        point_labels=data.get("point_labels").to(device),
+                    )[0]
+                pred[pred < 0] = 0.0
+                # if it's multichannel, perform argmax
+                if object_num > 1:
+                    # concatenate a background channel; make sure the user did not provide 0 as a prompt.
+                    is_bk = torch.all(pred <= 0, dim=0, keepdim=True)
+                    pred = pred.argmax(0).unsqueeze(0).float() + 1.0
+                    pred[is_bk] = 0.0
+                else:
+                    # AsDiscrete will remove NaN
+                    # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)
+                    pred[pred > 0] = 1.0
+                if "label_prompt" in data and data["label_prompt"] is not None:
+                    pred += 0.5  # inplace mapping to avoid cloning pred
+                    label_prompt = data["label_prompt"].to(device)  # ensure label_prompt is on the same device
+                    for i in range(1, object_num + 1):
+                        frac = i + 0.5
+                        pred[pred == frac] = label_prompt[i - 1].to(pred.dtype)
+                    pred[pred == 0.5] = 0.0
+                data[keys] = pred
+        return data
+
+
+class Relabeld(MapTransform):
+    def __init__(
+        self,
+        keys: KeysCollection,
+        label_mappings: dict[str, list[tuple[int, int]]],
+        dtype: DtypeLike = np.int16,
+        dataset_key: str = "dataset_name",
+        allow_missing_keys: bool = False,
+    ) -> None:
+        """
+        Remap the voxel labels in the input data dictionary based on the specified mapping.
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            label_mappings: a dictionary that specifies how local dataset class indexes are mapped to the
+                global class indexes. The dictionary keys are dataset names and the values are lists of
+                (local label, global label) pairs. This list of local -> global label mappings will be
+                applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
+                `label_mappings['default']` will be used. If `label_mappings[data[dataset_key]]` is None,
+                no relabeling will be performed. Set `label_mappings={}` to skip this transform entirely.
+            dtype: convert the output data to dtype, default to int16.
+            dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
+            allow_missing_keys: don't raise exception if key is missing.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.mappers = {}
+        self.dataset_key = dataset_key
+        for name, mapping in label_mappings.items():
+            self.mappers[name] = MapLabelValue(
+                orig_labels=[int(pair[0]) for pair in mapping],
+                target_labels=[int(pair[1]) for pair in mapping],
+                dtype=dtype,
+            )
+
+    def __call__(self, data):
+        d = dict(data)
+        dataset_name = d.get(self.dataset_key, "default")
+        _m = look_up_option(dataset_name, self.mappers, default=None)
+        if _m is None:
+            return d
+        for key in self.key_iterator(d):
+            d[key] = _m(d[key])
+        return d
diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index a080284e7c..bd99765348 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -300,6 +300,7 @@ def sliding_window_inference(
 
     # remove padding if image_size smaller than roi_size
     if any(pad_size):
+        kwargs.update({"pad_size": pad_size})
         for ss, output_i in enumerate(output_image_list):
             zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
             final_slicing: list[slice] = []
diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py
index fe7f93d493..9148e36542 100644
--- a/monai/networks/nets/vista3d.py
+++ b/monai/networks/nets/vista3d.py
@@ -23,7 +23,7 @@
 from monai.networks.blocks import MLPBlock, UnetrBasicBlock
 from monai.networks.nets import SegResNetDS2
 from monai.transforms.utils import convert_points_to_disc
-from monai.transforms.utils import get_largest_connected_component_mask_point as lcc
+from monai.transforms.utils import keep_merge_components_with_points as lcc
 from monai.transforms.utils import sample_points_from_label
 from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
 
@@ -78,6 +78,35 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head:
         self.NINF_VALUE = -9999
         self.PINF_VALUE = 9999
 
+    def update_slidingwindow_padding(
+        self,
+        pad_size: list | None,
+        labels: torch.Tensor | None,
+        prev_mask: torch.Tensor | None,
+        point_coords: torch.Tensor | None,
+    ):
+        """
+        The image has been padded by the sliding window inferer; the same padding needs to be applied to labels,
+        prev_mask, and point_coords, since the inferer does not pad these related inputs itself.
+
+        Args:
+            pad_size: padding size passed from sliding window inferer.
+            labels: image label ground truth.
+            prev_mask: previous segmentation mask.
+            point_coords: point click coordinates.
+        """
+        if pad_size is None:
+            return labels, prev_mask, point_coords
+        if labels is not None:
+            labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
+        if prev_mask is not None:
+            prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
+        if point_coords is not None:
+            point_coords = point_coords + torch.tensor(
+                [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
+            )
+        return labels, prev_mask, point_coords
+
     def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
         """Get number of foreground classes based on class and point prompt."""
         if class_vector is None:
@@ -317,6 +346,7 @@ def forward(
         prev_mask: torch.Tensor | None = None,
         radius: int | None = None,
         val_point_sampler: Callable | None = None,
+        transpose: bool = False,
         **kwargs,
     ):
         """
@@ -329,7 +359,7 @@ def forward(
             point_coords: [B, N, 3]
             point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
                 2/3 means negative/positive points for special supported classes like tumor.
- class_vector: [B, 1], the global class index + class_vector: [B, 1], the global class index. prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if the points are for zero-shot or supported class. When class_vector and point_coords are both provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] @@ -346,8 +376,12 @@ def forward( radius: single float value controling the gaussian blur when combining point and auto results. The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. - + transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from + sliding window inferer/point inferer. """ + labels, prev_mask, point_coords = self.update_slidingwindow_padding( + kwargs.get("pad_size", None), labels, prev_mask, point_coords + ) image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: @@ -424,9 +458,10 @@ def forward( point_labels, # type: ignore mapping_index, ) - if kwargs.get("keep_cache", False) and class_vector is None: self.image_embeddings = out.detach() + if transpose: + logits = logits.transpose(1, 0) return logits diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 363fce91be..7027c07d67 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -107,7 +107,8 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", - "get_largest_connected_component_mask_point", + "keep_merge_components_with_points", + "keep_components_with_positive_points", "convert_points_to_disc", "remove_small_objects", "img_bounds", @@ -1178,7 +1179,7 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] -def get_largest_connected_component_mask_point( +def keep_merge_components_with_points( img_pos: NdarrayTensor, img_neg: NdarrayTensor, point_coords: NdarrayTensor, @@ -1188,8 +1189,8 @@ def get_largest_connected_component_mask_point( margins: int = 3, ) -> NdarrayTensor: """ - Gets the connected component of img_pos and img_neg that include the positive points and - negative points separately. The function is used for combining automatic results with interactive + Keep connected regions of img_pos and img_neg that include the positive points and + negative points separately. The function is used for merging automatic results with interactive results in VISTA3D. Args: @@ -1199,6 +1200,7 @@ def get_largest_connected_component_mask_point( neg_val: negative point label values. point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. point_labels: the label of each point, shape [B, N]. + margins: include points outside of the region but within the margin. """ cucim_skimage, has_cucim = optional_import("cucim.skimage") @@ -1249,6 +1251,49 @@ def get_largest_connected_component_mask_point( return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] +def keep_components_with_positive_points( + img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor +) -> torch.Tensor: + """ + Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove + regions without positive points. + Args: + img: [1, B, H, W, D]. Output prediction from VISTA3D. 
Values are before sigmoid and may contain NaN.
+        point_coords: [B, N, 3]. Point click coordinates.
+        point_labels: [B, N]. Point click labels.
+    """
+    if not has_measure:
+        raise RuntimeError("skimage.measure required.")
+    outs = torch.zeros_like(img)
+    for c in range(len(point_coords)):
+        if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
+            # skip if no positive points.
+            continue
+        coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
+        not_nan_mask = ~torch.isnan(img[0, c])
+        img_ = torch.nan_to_num(img[0, c] > 0, 0)
+        img_, *_ = convert_data_type(img_, np.ndarray)  # type: ignore
+        label = measure.label
+        features = label(img_, connectivity=3)
+        pos_mask = torch.from_numpy(img_).to(img.device) > 0
+        # move the connected-component map to the same device as the image
+        features = torch.from_numpy(features).to(img.device)
+        # collect the ids of the components that contain a positive point
+        idx = []
+        for p in coords:
+            idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
+        idx = list(set(idx))
+        for i in idx:
+            if i == 0:
+                continue
+            outs[0, c] += features == i
+        outs = outs > 0
+        # fill positive regions outside the kept components with the mean of the non-kept, non-NaN values
+        fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
+        img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
+    return img
+
+
 def convert_points_to_disc(
     image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
 ):
@@ -1269,7 +1314,7 @@ def convert_points_to_disc(
     _array = [
         torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
     ]
-    coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0])
+    coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
     # [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
     coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
     coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 479c4c8dc2..f80d06f5d3 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -210,6 +210,7 @@ def run_testsuit():
         "test_perceptual_loss",
         "test_ultrasound_confidence_map_transform",
         "test_vista3d_utils",
+        "test_vista3d_transforms",
     ]
 
     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py
new file mode 100644
index 0000000000..1b293288c4
--- /dev/null
+++ b/tests/test_point_based_window_inferer.py
@@ -0,0 +1,77 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
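As an aside before the new tests: a toy example of `keep_components_with_positive_points` (added to `monai/transforms/utils.py` above) may clarify its behavior. It requires `skimage`, and the blob sizes and click location are arbitrary; the connected component that received a positive click is kept, while the other positive blob is suppressed toward the mean background response:

import torch

from monai.transforms.utils import keep_components_with_positive_points

pred = torch.full((1, 1, 32, 32, 32), -10.0)  # [1, B, H, W, D] pre-sigmoid logits
pred[0, 0, 2:6, 2:6, 2:6] = 10.0  # positive component A
pred[0, 0, 20:24, 20:24, 20:24] = 10.0  # positive component B
point_coords = torch.tensor([[[4, 4, 4]]])  # one positive click, inside component A
point_labels = torch.tensor([[1]])

out = keep_components_with_positive_points(pred, point_coords, point_labels)
assert (out[0, 0, 2:6, 2:6, 2:6] > 0).all()  # the clicked component is kept
assert (out[0, 0, 20:24, 20:24, 20:24] <= 0).all()  # the unclicked one is filled with the (negative) mean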
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.inferer import point_based_window_inferer +from monai.networks import eval_mode +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +_, has_tqdm = optional_import("tqdm") + +TEST_CASES = [ + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + "point_start": 1, + }, + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestPointBasedWindowInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vista3d(self, vista3d_params, inputs_shape, inferer_params): + vista3d = vista3d132(**vista3d_params).to(device) + with eval_mode(vista3d): + inferer_params["predictor"] = vista3d + inferer_params["inputs"] = torch.randn(*inputs_shape).to(device) + stitched_output = point_based_window_inferer(**inferer_params) + self.assertEqual(stitched_output.shape, inputs_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_sampler.py b/tests/test_vista3d_sampler.py new file mode 100644 index 0000000000..6945d250d2 --- /dev/null +++ b/tests/test_vista3d_sampler.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
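To make the sampler's contract concrete before its tests, here is a sketch of how `sample_prompt_pairs` might be wired into a training step (the label volume and probabilities are made-up values; the all-`None` case follows the comment in the sampler itself):

import torch

from monai.apps.vista3d.sampler import sample_prompt_pairs

labels = torch.randint(0, 3, (1, 1, 64, 64, 64))  # [1, 1, H, W, D] ground truth with classes {0, 1, 2}
label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
    labels, label_set=[0, 1, 2], max_prompt=5, max_backprompt=1, drop_label_prob=0.2, drop_point_prob=0.2
)
if label_prompt is None and point is None:
    print("no effective prompt; skip this iteration in the trainer")
else:
    # label_prompt/prompt_class: [B, 1]; point: [B, N, 3]; point_label: [B, N]
    print(prompt_class.shape, None if point is None else point.shape)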
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.sampler import sample_prompt_pairs + +label = torch.zeros([1, 1, 64, 64, 64]) +label[:, :, :10, :10, :10] = 1 +label[:, :, 20:30, 20:30, 20:30] = 2 +label[:, :, 30:40, 30:40, 30:40] = 3 +label1 = torch.zeros([1, 1, 64, 64, 64]) + +TEST_VISTA_SAMPLE_PROMPT = [ + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 0, + }, + [4, 4, 4, 4], + ], + [ + { + "labels": label, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [2, None, None, 2], + ], + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 1, + "drop_point_prob": 0, + }, + [None, 3, 3, 3], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [1, None, None, 1], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 0, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [None, None, None, None], + ], +] + + +class TestGeneratePrompt(unittest.TestCase): + @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT) + def test_result(self, input_data, expected): + output = sample_prompt_pairs(**input_data) + result = [i.shape[0] if i is not None else None for i in output] + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py new file mode 100644 index 0000000000..9d61fe2fc2 --- /dev/null +++ b/tests/test_vista3d_transforms.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
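Similarly, the pre/post transforms compose as below; the `labels_dict`/`subclass` values are illustrative only (following the "lung" example in the docstring), and the random tensor stands in for the network's pre-sigmoid output:

import torch

from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd

# labels_dict maps index -> name; the pre-transform inverts it to resolve name prompts
pre = VistaPreTransformd(
    keys="image",
    labels_dict={"3": "lung", "4": "left lung", "5": "right lung"},
    subclass={"3": [4, 5]},
    special_index=(1, 2),
)
data = {"label_prompt": ["lung"], "points": [[10, 10, 10]], "point_labels": [1]}
data = pre(data)  # "lung" -> 3 -> subclasses [4, 5]; point labels unchanged (3 is not special)

data["pred"] = torch.randn(2, 64, 64, 64)  # one channel per prompted subclass
data["label_prompt"] = torch.tensor([4, 5])
data = VistaPostTransformd(keys="pred")(data)  # argmax over channels, re-indexed to {0, 4, 5}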
+ +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd +from monai.utils import min_version +from monai.utils.module import optional_import + +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TEST_VISTA_PRETRANSFORM = [ + [ + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [1]}, + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [3]}, + ], + [ + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [2]}, + ], + [ + {"label_prompt": [3], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [4, 5], "points": [[0, 0, 0]], "point_labels": [0]}, + ], + [ + {"label_prompt": [6], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [7, 8], "points": [[0, 0, 0]], "point_labels": [0]}, + ], +] + + +pred1 = torch.zeros([2, 64, 64, 64]) +pred1[0, :10, :10, :10] = 1 +pred1[1, 20:30, 20:30, 20:30] = 1 +output1 = torch.zeros([1, 64, 64, 64]) +output1[:, :10, :10, :10] = 2 +output1[:, 20:30, 20:30, 20:30] = 3 + +# -1 is needed since pred should be before sigmoid. +pred2 = torch.zeros([1, 64, 64, 64]) - 1 +pred2[:, :10, :10, :10] = 1 +pred2[:, 20:30, 20:30, 20:30] = 1 +output2 = torch.zeros([1, 64, 64, 64]) +output2[:, 20:30, 20:30, 20:30] = 1 + +TEST_VISTA_POSTTRANSFORM = [ + [{"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}, output1.to(device)], + [ + { + "pred": pred2.to(device), + "points": torch.tensor([[25, 25, 25]]).to(device), + "point_labels": torch.tensor([1]).to(device), + }, + output2.to(device), + ], +] + + +class TestVistaPreTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_PRETRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPreTransformd(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) + result = transform(input_data) + self.assertEqual(result, expected) + + +@skipUnless(has_measure, "skimage.measure required") +class TestVistaPostTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_POSTTRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPostTransformd(keys="pred") + result = transform(input_data) + self.assertEqual((result["pred"] == expected).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index a940854d88..5a0caedd61 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -18,11 +18,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils import ( - convert_points_to_disc, - get_largest_connected_component_mask_point, - sample_points_from_label, -) +from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label from monai.utils import min_version from monai.utils.module import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick @@ -57,6 +53,31 @@ expected_shape, ] ) + image_size = (16, 32, 64) + point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) + point_label = torch.tensor([[1, 0]]) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": 
point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_CONVERT_POINTS_TO_DISC_VALUE = [] +image_size = (16, 32, 64) +point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) +point_label = torch.tensor([[1, 0]]) +expected_shape = (point.shape[0], 2, *image_size) +for radius in [5, 10]: + for disc in [True, False]: + TEST_CONVERT_POINTS_TO_DISC_VALUE.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + [point, point_label], + ] + ) + TEST_LCC_MASK_POINT_TORCH = [] for bs in [1, 2]: @@ -108,9 +129,17 @@ def test_shape(self, input_data, expected_shape): result = convert_points_to_disc(**input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE) + def test_value(self, input_data, points): + result = convert_points_to_disc(**input_data) + point, point_label = points + for i in range(point.shape[0]): + for j in range(point.shape[1]): + self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True) + @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") -class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): +class TestKeepMergeComponentsWithPoints(unittest.TestCase): @skip_if_quick @skip_if_no_cuda @@ -119,13 +148,13 @@ class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): def test_cp_shape(self, input_data, shape): for key in input_data: input_data[key] = input_data[key].to(device) - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape) @skipUnless(has_measure, "skimage required") @parameterized.expand(TEST_LCC_MASK_POINT_NP) def test_np_shape(self, input_data, shape): - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape)
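Finally, a standalone check of the `torch.meshgrid` axis-order fix in `convert_points_to_disc` (equivalent to the new `test_value` case above): a non-cubic `image_size` is what exposes the old bug, because swapping the first and last axes only matters when the spatial dimensions differ.

import torch

from monai.transforms.utils import convert_points_to_disc

image_size = (16, 32, 64)  # deliberately non-cubic
point = torch.tensor([[[8, 16, 42]]])  # [B=1, N=1, 3]
point_label = torch.tensor([[1]])  # positive point -> channel 1

mask = convert_points_to_disc(image_size, point, point_label, radius=5, disc=True)
assert mask.shape == (1, 2, *image_size)
assert bool(mask[0, 1, 8, 16, 42])  # the disc is centered on the click coordinate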