From 129a3001f015d7be4903268310b048ee57b46193 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Fri, 2 Aug 2024 15:00:39 -0400 Subject: [PATCH 01/30] Add vista network Signed-off-by: heyufan1995 --- monai/networks/nets/segresnet_ds.py | 181 ++++++ monai/networks/nets/vista3d.py | 952 ++++++++++++++++++++++++++++ monai/transforms/utils.py | 246 +++++++ 3 files changed, 1379 insertions(+) create mode 100644 monai/networks/nets/vista3d.py diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 6430f5fdc9..1041c07259 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -425,3 +425,184 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]: return self._forward(x) + + +class SegResNetDS2(SegResNetDS): + """ + SegResNetDS2 is the image encoder used by VISTA3D. It adds one additional decoder branch. + """ + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + blocks_up: tuple | None = None, + dsdepth: int = 1, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, + ): + super().__init__( + spatial_dims = spatial_dims, + init_filters=init_filters, + in_channels=in_channels, + out_channels= out_channels, + act = act, + norm = norm, + blocks_down = blocks_down, + blocks_up = blocks_up, + dsdepth = dsdepth, + preprocess = preprocess, + upsample_mode = upsample_mode, + resolution = resolution) + + if spatial_dims not in (1, 2, 3): + raise ValueError("`spatial_dims` can only be 1, 2 or 3.") + + if resolution is not None: + if not isinstance(resolution, (list, tuple)): + raise TypeError("resolution must be a tuple") + elif not all(r > 0 for r in resolution): + raise ValueError("resolution must be positive") + + # ensure normalization had affine trainable parameters (if not specified) + norm = split_args(norm) + if has_option(Norm[norm[0], spatial_dims], "affine"): + norm[1].setdefault("affine", True) # type: ignore + + # ensure activation is inplace (if not specified) + act = split_args(act) + if has_option(Act[act[0]], "inplace"): + act[1].setdefault("inplace", True) # type: ignore + + n_up = len(blocks_down) - 1 + + filters = init_filters * 2**n_up + self.up_layers_auto = nn.ModuleList() + + # self.anisotropic_scales and self.blocks_up are created within super().init() + + for i in range(n_up): + filters = filters // 2 + kernel_size, _, stride = ( + aniso_kernel(self.anisotropic_scales[len(self.blocks_up) - i - 1]) + if self.anisotropic_scales + else (3, 1, 2) + ) + + level_auto = nn.ModuleDict() + blocks = [ + SegResBlock( + spatial_dims=spatial_dims, + in_channels=filters, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + for _ in range(self.blocks_up[i]) + ] + level_auto["blocks"] = nn.Sequential(*blocks) + if len(self.blocks_up) - i <= dsdepth: # deep supervision heads + level_auto["head"] = Conv[Conv.CONV, spatial_dims]( + in_channels=filters, + out_channels=out_channels, + kernel_size=1, + bias=True, + ) + else: + level_auto["head"] = nn.Identity() + self.up_layers_auto.append(level_auto) + + if ( + n_up == 0 + ): # in a corner case of flat structure (no downsampling), attache a single head + level_auto = nn.ModuleDict( + { + "upsample": nn.Identity(), + 
"blocks": nn.Identity(), + "head": Conv[Conv.CONV, spatial_dims]( + in_channels=filters, + out_channels=out_channels, + kernel_size=1, + bias=True, + ), + } + ) + self.up_layers_auto.append(level_auto) + + def _forward( + self, x: torch.Tensor, with_point, with_label + ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + if self.preprocess is not None: + x = self.preprocess(x) + + if not self.is_valid_shape(x): + raise ValueError( + f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}" + ) + + x_down = self.encoder(x) + + x_down.reverse() + x = x_down.pop(0) + + if len(x_down) == 0: + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] + + outputs: list[torch.Tensor] = [] + outputs_auto: list[torch.Tensor] = [] + x_ = x.clone() + if with_point: + i = 0 + for level in self.up_layers: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs.append(level["head"](x)) + i = i + 1 + + outputs.reverse() + x = x_ + if with_label: + i = 0 + for level in self.up_layers_auto: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs_auto.append(level["head"](x)) + i = i + 1 + + outputs_auto.reverse() + + # in eval() mode, always return a single final output + if not self.training or len(outputs) == 1: + outputs = outputs[0] if len(outputs) == 1 else outputs + + if not self.training or len(outputs_auto) == 1: + outputs_auto = outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto + + # return a list of DS outputs + return outputs, outputs_auto + + def forward( + self, x: torch.Tensor, with_point=True, with_label=True, **kwargs + ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + return self._forward(x, with_point, with_label) + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + for param in self.encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + + for param in self.up_layers_auto.parameters(): + param.requires_grad = not auto_freeze + + for param in self.up_layers.parameters(): + param.requires_grad = not point_freeze diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py new file mode 100644 index 0000000000..9237f0d99d --- /dev/null +++ b/monai/networks/nets/vista3d.py @@ -0,0 +1,952 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +from typing import Any, Optional, Tuple, Type + +from torch import Tensor, nn +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import monai +from monai.networks.blocks import UnetrBasicBlock +from monai.transforms.utils import get_largest_connected_component_mask_point as lcc +from monai.transforms.utils import convert_points_to_disc, sample_points_from_label + +from scripts.utils.workflow_utils import sample_points_patch_val + + +rearrange, _ = optional_import("einops", name="rearrange") + +__all__ = ["VISTA3D"] + + + + +class VISTA3D(nn.Module): + """ + VISTA3D based on `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography + https://arxiv.org/abs/2406.05285>`_. + Args: + image_encoder: image encoder backbone for feature extraction. + class_head: class head used for class index based segmentation + point_head: point head used for interactive segmetnation + """ + def __init__(self, image_encoder, class_head, point_head): + super().__init__() + self.image_encoder = image_encoder + self.class_head = class_head + self.point_head = point_head + self.image_embeddings = None + self.auto_freeze = False + self.point_freeze = False + self.NINF_VALUE = -9999 + self.PINF_VALUE = 9999 + + def get_bs(self, class_vector, point_coords): + if class_vector is None: + assert point_coords is not None, "prompt is required" + return point_coords.shape[0] + else: + return class_vector.shape[0] + + def convert_point_label(self, point_label, label_set=None, + special_index=[23, 24, 25, 26, 27, 57, 128]): + if label_set is None: + return point_label + assert point_label.shape[0] == len(label_set) + for i in range(len(label_set)): + if label_set[i] in special_index: + for j in range(len(point_label[i])): + point_label[i, j] = ( + point_label[i, j] + 2 + if point_label[i, j] > -1 + else point_label[i, j] + ) + return point_label + + def sample_points_patch_val( + self, + labels, + patch_coords, + label_set, + use_center=True, + mapped_label_set=None, + max_ppoint=1, + max_npoint=0, + **kwargs + ): + """Sample points for patch during sliding window validation. Only used for point only validation. + Args: + labels: [1, 1, H, W, D] + patch_coords: sliding window slice object + label_set: local index, must match values in labels + use_center: sample points from the center + mapped_label_set: global index, it is used to identify special classes. + max_ppoint/max_npoint: positive points and negative points to sample. + """ + point_coords, point_labels = sample_points_from_label( + labels[patch_coords], + label_set, + max_ppoint=max_ppoint, + max_npoint=max_npoint, + device=labels.device, + use_center=use_center, + ) + point_labels = self.convert_point_label(point_labels, mapped_label_set) + return ( + point_coords, + point_labels, + torch.tensor(mapped_label_set).to(point_coords.device).unsqueeze(-1), + ) + + def update_point_to_patch(self, patch_coords, point_coords, point_labels): + """ Update point_coords with respect to patch coords. 
+ If point is outside of the patch, remove the coordinates and set label to -1 + """ + patch_ends = [ + patch_coords[-3].stop, + patch_coords[-2].stop, + patch_coords[-1].stop, + ] + patch_starts = [ + patch_coords[-3].start, + patch_coords[-2].start, + patch_coords[-1].start, + ] + # update point coords + patch_starts = ( + torch.tensor(patch_starts, device=point_coords.device) + .unsqueeze(0) + .unsqueeze(0) + ) + patch_ends = ( + torch.tensor(patch_ends, device=point_coords.device) + .unsqueeze(0) + .unsqueeze(0) + ) + # [1 N 1] + indices = torch.logical_and( + ((point_coords - patch_starts) > 0).all(2), + ((patch_ends - point_coords) > 0).all(2), + ) + # check if it's within patch coords + point_coords = point_coords.clone() - patch_starts + point_labels = point_labels.clone() + if indices.any(): + point_labels[~indices] = -1 + point_coords[~indices] = 0 + # also remove padded points, mainly used for inference. + not_pad_indices = (point_labels != -1).any(0) + point_coords = point_coords[:, not_pad_indices] + point_labels = point_labels[:, not_pad_indices] + else: + point_coords = None + point_labels = None + return point_coords, point_labels + + def connected_components_combine( + self, logits, point_logits, point_coords, point_labels, mapping_index, thred=0.5 + ): + """Combine auto results with point click response, or combine previous mask with point click response. + For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed + region in point clicks must be updated by the lcc function. Notice, if a positive point is within logits/prev_mask, the components containing the positive point + will be added. + """ + logits = ( + logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + ) + _logits = logits[mapping_index] + inside = [] + for i in range(_logits.shape[0]): + inside.append( + np.any( + [ + _logits[ + i, + 0, + round(p[0].item()), + round(p[1].item()), + round(p[2].item()), + ].item() + > 0 + for p in point_coords[i] + ] + ) + ) + inside = torch.tensor(inside).to(logits.device) + nan_mask = torch.isnan(_logits) + _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() + pos_region = point_logits.sigmoid() > thred + diff_pos = torch.logical_and( + torch.logical_or( + (_logits <= thred), + inside.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), + ), + pos_region, + ) + diff_neg = torch.logical_and((_logits > thred), ~pos_region) + cc = lcc( + diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels + ) + # cc is the region that can be updated by point_logits. + cc = cc.to(logits.device) + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, only remove unconnected positive region. 
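+        # "uc" = unconnected: positive point responses outside cc are forced negative wherever the
+        # auto head produced NaN, so the blend below (logits*(1-cc) + point_logits*cc over cc|nan_mask)
+        # does not re-introduce them.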
+ uc_pos_region = torch.logical_and(pos_region, ~cc) + fill_mask = torch.logical_and(nan_mask, uc_pos_region) + if fill_mask.any(): + # fill in the mean negative value + point_logits[fill_mask] = -1 + # replace logits nan value and cc with point_logits + cc = torch.logical_or(nan_mask, cc).to(logits.dtype) + logits[mapping_index] *= 1 - cc + logits[mapping_index] += cc * point_logits + # debug_ccp(_logits, point_logits.sigmoid(), point_coords, point_labels, diff, cc, logits[mapping_index], np.random.randint(10000)) + return logits + + def gaussian_combine( + self, logits, point_logits, point_coords, point_labels, mapping_index, radius + ): + """Combine point results with auto results using gaussian.""" + if radius is None: + radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 + weight = 1 - convert_points_to_disc( + point_logits.shape[-3:], point_coords, point_labels, radius=radius + ).sum(1, keepdims=True) + weight[weight < 0] = 0 + logits = ( + logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + ) + logits[mapping_index] *= weight + logits[mapping_index] += (1 - weight) * point_logits + return logits + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + """Freeze auto-branch or point-branch""" + if auto_freeze != self.auto_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad( + auto_freeze=auto_freeze, point_freeze=point_freeze + ) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.class_head.parameters(): + param.requires_grad = not auto_freeze + self.auto_freeze = auto_freeze + + if point_freeze != self.point_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad( + auto_freeze=auto_freeze, point_freeze=point_freeze + ) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.point_head.parameters(): + param.requires_grad = not point_freeze + self.point_freeze = point_freeze + + def forward( + self, + input_images, + point_coords=None, + point_labels=None, + class_vector=None, + prompt_class=None, + patch_coords=None, + labels=None, + label_set=None, + prev_mask=None, + radius=None, + val_point_sampler=None, + **kwargs, + ): + """ + The forward function for VISTA3D. We only support single patch in training and inference. + One exception is allowing sliding window batch size > 1 for automatic segmentation only case. + B represents number of objects, N represents number of points for each objects. + Args: + input_images: [1, 1, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. + 2/3 means negative/postive ponits for special supported class like tumor. + class_vector: [B, 1], the global class index + prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if + the points are for zero-shot or supported class. When class_vector and point_coords are both + provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] + will be considered novel class. + patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is + passed from monai_utils.sliding_window_inferer. This is an indicator for training phase or validation phase. 
+ labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. This is the transposed raw output from sliding_window_inferer before + any postprocessing. When user click points to perform auto-results correction, this can be the auto-results. + radius: single float value controling the gaussian blur when combining point and auto results. The gaussian combine is not used + in VISTA3D training but might be useful for finetuning purposes. + val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. + + """ + image_size = input_images.shape[-3:] + device = input_images.device + if point_coords is None and class_vector is None: + return NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + + bs = self.get_bs(class_vector, point_coords) + if patch_coords is not None: + # if during validation and perform enable based point-validation. + if labels is not None and label_set is not None: + # if labels is not None, sample from labels for each patch. + if val_point_sampler is None: + val_point_sampler = sample_points_patch_val + point_coords, point_labels, prompt_class = val_point_sampler( + labels, patch_coords, label_set + ) + if prompt_class[0].item() == 0: + point_labels[0] = -1 + labels, prev_mask = None, None + elif point_coords is not None: + # If not performing patch-based point only validation, use user provided click points for inference. + # the point clicks is in original image space, convert it to current patch-coordinate space. 
+ point_coords, point_labels = self.update_point_to_patch( + patch_coords, point_coords, point_labels + ) + + if point_coords is not None and point_labels is not None: + # remove points that used for padding purposes (point_label = -1) + mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool) + if mapping_index.any(): + point_coords = point_coords[mapping_index] + point_labels = point_labels[mapping_index] + if prompt_class is not None: + prompt_class = prompt_class[mapping_index] + else: + if self.auto_freeze or (class_vector is None and patch_coords is None): + # if auto_freeze, point prompt must exist to allow loss backward + # in training, class_vector and point cannot both be None due to loss.backward() + mapping_index.fill_(True) + else: + point_coords, point_labels = None, None + + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + + if ( + self.image_embeddings is not None + and kwargs.get("keep_cache", False) + and class_vector is None + ): + out, out_auto = self.image_embeddings, None + else: + out, out_auto = self.image_encoder( + input_images, + with_point=point_coords is not None, + with_label=class_vector is not None, + ) + input_images = None + + # force releasing memories that set to None + torch.cuda.empty_cache() + if class_vector is not None: + logits, _ = self.class_head(out_auto, class_vector) + if point_coords is not None: + point_logits = self.point_head( + out, point_coords, point_labels, class_vector=prompt_class + ) + if patch_coords is None: + logits = self.gaussian_combine( + logits, + point_logits, + point_coords, + point_labels, + mapping_index, + radius, + ) + else: + # during validation use largest component + logits = self.connected_components_combine( + logits, point_logits, point_coords, point_labels, mapping_index + ) + else: + logits = self.NINF_VALUE + torch.zeros( + [bs, 1, *image_size], device=device, dtype=out.dtype + ) + logits[mapping_index] = self.point_head( + out, point_coords, point_labels, class_vector=prompt_class + ) + if prev_mask is not None and patch_coords is not None: + logits = self.connected_components_combine( + prev_mask[patch_coords].transpose(1, 0).to(logits.device), + logits[mapping_index], + point_coords, + point_labels, + mapping_index, + ) + + if kwargs.get("keep_cache", False) and class_vector is None: + self.image_embeddings = out.detach() + return logits + +class Point_Mapping_SAM(nn.Module): + def __init__( + self, + feature_size, + max_prompt=32, + num_add_mask_tokens=2, + n_classes=512, + last_supported=132, + ): + super().__init__() + transformer_dim = feature_size + self.max_prompt = max_prompt + self.feat_downsample = nn.Sequential( + nn.Conv3d( + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=2, + padding=1, + ), + nn.InstanceNorm3d(feature_size), + nn.GELU(), + nn.Conv3d( + in_channels=feature_size, + out_channels=transformer_dim, + kernel_size=3, + stride=1, + padding=1, + ), + nn.InstanceNorm3d(feature_size), + ) + + self.mask_downsample = nn.Conv3d( + in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1 + ) + + self.transformer = TwoWayTransformer( + depth=2, + embedding_dim=transformer_dim, + mlp_dim=512, + num_heads=4, + ) + self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) + self.point_embeddings = nn.ModuleList( + [nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)] + ) + self.not_a_point_embed = nn.Embedding(1, transformer_dim) + 
self.special_class_embed = nn.Embedding(1, transformer_dim) + self.mask_tokens = nn.Embedding(1, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d( + transformer_dim, + transformer_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + nn.InstanceNorm3d(transformer_dim), + nn.GELU(), + nn.Conv3d( + transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1 + ), + ) + + self.output_hypernetworks_mlps = MLP( + transformer_dim, transformer_dim, transformer_dim, 3 + ) + + ## MultiMask output + self.num_add_mask_tokens = num_add_mask_tokens + self.output_add_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim, 3) + for i in range(self.num_add_mask_tokens) + ] + ) + # class embedding + self.n_classes = n_classes + self.last_supported = last_supported + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.zeroshot_embed = nn.Embedding(1, transformer_dim) + self.supported_embed = nn.Embedding(1, transformer_dim) + + def forward(self, out, point_coords, point_labels, class_vector=None): + # downsample out + out_low = self.feat_downsample(out) + out_shape = out.shape[-3:] + out = None + torch.cuda.empty_cache() + # embed points + points = point_coords + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) + point_embedding[point_labels == -1] = 0.0 + point_embedding[point_labels == -1] += self.not_a_point_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight + point_embedding[point_labels == 1] += self.point_embeddings[1].weight + point_embedding[point_labels == 2] += ( + self.point_embeddings[0].weight + self.special_class_embed.weight + ) + point_embedding[point_labels == 3] += ( + self.point_embeddings[1].weight + self.special_class_embed.weight + ) + output_tokens = self.mask_tokens.weight + + output_tokens = output_tokens.unsqueeze(0).expand( + point_embedding.size(0), -1, -1 + ) + if class_vector is None: + tokens_all = torch.cat( + ( + output_tokens, + point_embedding, + self.supported_embed.weight.unsqueeze(0).expand( + point_embedding.size(0), -1, -1 + ), + ), + dim=1, + ) + # tokens_all = torch.cat((output_tokens, point_embedding), dim=1) + else: + class_embeddings = [] + for i in class_vector: + if i > self.last_supported: + class_embeddings.append(self.zeroshot_embed.weight) + else: + class_embeddings.append(self.supported_embed.weight) + class_embeddings = torch.stack(class_embeddings) + tokens_all = torch.cat( + (output_tokens, point_embedding, class_embeddings), dim=1 + ) + # cross attention + masks = [] + max_prompt = self.max_prompt + for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))): + # remove variables in previous for loops to save peak memory for self.transformer + src, upscaled_embedding, hyper_in = None, None, None + torch.cuda.empty_cache() + idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) + tokens = tokens_all[idx[0] : idx[1]] + src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) + pos_src = torch.repeat_interleave( + self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0 + ) + b, c, h, w, d = src.shape + hs, src = self.transformer(src, pos_src, tokens) + mask_tokens_out = hs[:, :1, :] + hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) + src = src.transpose(1, 2).view(b, c, h, w, d) + upscaled_embedding = self.output_upscaling(src) + b, c, h, w, d = upscaled_embedding.shape + masks.append( + (hyper_in @ 
upscaled_embedding.view(b, c, h * w * d)).view( + b, -1, h, w, d + ) + ) + masks = torch.vstack(masks) + return masks + + +class Class_Mapping_Classify(nn.Module): + def __init__(self, n_classes, feature_size, use_mlp=False): + super().__init__() + self.use_mlp = use_mlp + if use_mlp: + self.mlp = nn.Sequential( + nn.Linear(feature_size, feature_size), + nn.InstanceNorm1d(1), + nn.GELU(), + nn.Linear(feature_size, feature_size), + ) + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.image_post_mapping = nn.Sequential( + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + ) + + def forward(self, src, class_vector): + b, c, h, w, d = src.shape + src = self.image_post_mapping(src) + class_embedding = self.class_embeddings(class_vector) + if self.use_mlp: + class_embedding = self.mlp(class_embedding) + # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. + masks = [] + for i in range(b): + mask = (class_embedding @ src[[i]].view(1, c, h * w * d)).view( + -1, 1, h, w, d + ) + masks.append(mask) + masks = torch.cat(masks, 1) + return masks, class_embedding + + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/facebookresearch/segment-anything +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. 
Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w, d = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." 
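+        # q/k/v are projected down to internal_dim = embedding_dim // downsample_rate to reduce the
+        # cost of the attention below; out_proj maps the result back to embedding_dim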
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((3, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + # [bs=1,N=2,2] @ [2,128] + # [bs=1, N=2, 128] + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + # [bs=1, N=2, 128+128=256] + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w, d = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w, d), device=device, dtype=torch.float32) + x_embed = grid.cumsum(dim=0) - 0.5 + y_embed = grid.cumsum(dim=1) - 0.5 + z_embed = grid.cumsum(dim=2) - 0.5 + x_embed = x_embed / h + y_embed = y_embed / w + z_embed = z_embed / d + pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) + return pe.permute(3, 0, 1, 2) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[0] + coords[:, :, 1] = coords[:, :, 1] / image_size[1] + coords[:, :, 2] = coords[:, :, 2] / image_size[2] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d8461d927b..f5dd2f7a88 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,6 +22,7 @@ import numpy as np import torch +import torch.nn.functional as F import monai from monai.config import DtypeLike, IndexSelection @@ -103,6 +104,10 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", + "get_largest_connected_component_mask_point", + "sample_points_from_label", + "erode3d", + "sample" "remove_small_objects", "img_bounds", "in_bounds", @@ -1171,6 +1176,247 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] +def get_largest_connected_component_mask_point( + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + pos_val: list=[1, 3], + neg_val: list=[0, 2], + point_coords: None = None, + point_labels: None = None, + margins: int = 3, +) -> NdarrayTensor: + """ + Gets the largest connected component mask of an image that include the point_coords. 
+ Args: + img_pos: [1, B, H, W, D] + point_coords [B, N, 3] + point_labels [B, N] + """ + + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + label = measure.label + lib = np + + features_pos, num_features = label(img_pos_, connectivity=3, return_num=True) + features_neg, num_features = label(img_neg_, connectivity=3, return_num=True) + + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] in pos_val: + features = features_pos + elif point_labels[bs, i] in neg_val: + features = features_neg + else: + # if -1 padding point, skip + continue + for margin in range(margins): + left, right = max(p[0].round().int().item() - margin, 0), min( + p[0].round().int().item() + margin + 1, features.shape[-3] + ) + t, d = max(p[1].round().int().item() - margin, 0), min( + p[1].round().int().item() + margin + 1, features.shape[-2] + ) + f, b = max(p[2].round().int().item() - margin, 0), min( + p[2].round().int().item() + margin + 1, features.shape[-1] + ) + if (features[bs, 0, left:right, t:d, f:b] > 0).any(): + index = features[bs, 0, left:right, t:d, f:b].max() + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + +def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): + """ + Convert a 3D point coordinates into image mask. The returned mask has the same spatial + size as `image_size` while the batch dimension is the same as point' batch dimension. + The point is converted to a mask ball with radius defined by `radius`. + Args: + image_size: The output size of th + point: [b, N, 3] + point_label: [b, N] + radius: disc ball radius size + disc: If true, use regular disc other other use gaussian. + """ + if not torch.is_tensor(point): + point = torch.from_numpy(point) + masks = torch.zeros( + [point.shape[0], 2, image_size[0], image_size[1], image_size[2]], + device=point.device, + ) + row_array = torch.arange( + start=0, end=image_size[0], step=1, dtype=torch.float32, device=point.device + ) + col_array = torch.arange( + start=0, end=image_size[1], step=1, dtype=torch.float32, device=point.device + ) + z_array = torch.arange( + start=0, end=image_size[2], step=1, dtype=torch.float32, device=point.device + ) + coord_rows, coord_cols, coord_z = torch.meshgrid(z_array, col_array, row_array) + # [1,3,h,w,d] -> [b, 2, 3, h,w,d] + coords = ( + torch.stack((coord_rows, coord_cols, coord_z), dim=0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(point.shape[0], 2, 1, 1, 1, 1) + ) + for b in range(point.shape[0]): + for n in range(point.shape[1]): + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + if disc: + masks[b, channel] += ( + torch.pow( + coords[b, channel] + - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), + 2, + ).sum(0) + < radius**2 + ) + else: + masks[b, channel] += torch.exp( + -torch.pow( + coords[b, channel] + - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), + 2, + ).sum(0) + / (2 * radius**2) + ) + return masks + +def sample_points_from_label( + labels, label_set=None, max_ppoint=1, max_npoint=0, device="cpu", use_center=False +): + """Sample points from labels. + Args: + labels: [1, 1, H, W, D] + label_set: local index, must match values in labels. + max_ppoint: maximum positive point samples. + max_npoint: maximum negative point samples. 
+ device: returned tensor device. + use_center: whether to sample points from center. + Returns: + point: point coordinates of [B, N, 3]. + point_label: [B, N], always 0 for negative, 1 for positive. + """ + assert labels.shape[0] == 1, "only support batch size 1" + labels = labels[0, 0] + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + Nn = max_npoint + Np = max_ppoint + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = get_largest_connected_component_mask(erode3d(plabels)) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices = torch.sort(pdis) + _point.append( + torch.stack( + [ + plabelpoints[sorted_indices[i]] + for i in range(min(len(plabelpoints), Np)) + ] + + random.choices(nlabelpoints, k=min(len(nlabelpoints), Nn)) + + [torch.tensor([0, 0, 0], device=device)] + * ( + Np + + Nn + - min(len(plabelpoints), Np) + - min(len(nlabelpoints), Nn) + ) + ) + ) + _point_label.append( + torch.tensor( + [1] * min(len(plabelpoints), Np) + + [0.0] * min(len(nlabelpoints), Nn) + + [-1] + * ( + Np + + Nn + - min(len(plabelpoints), Np) + - min(len(nlabelpoints), Nn) + ) + ).to(device) + ) + + else: + _point.append( + torch.stack( + random.choices(plabelpoints, k=min(len(plabelpoints), Np)) + + random.choices(nlabelpoints, k=min(len(nlabelpoints), Nn)) + + [torch.tensor([0, 0, 0], device=device)] + * ( + Np + + Nn + - min(len(plabelpoints), Np) + - min(len(nlabelpoints), Nn) + ) + ) + ) + _point_label.append( + torch.tensor( + [1] * min(len(plabelpoints), Np) + + [0.0] * min(len(nlabelpoints), Nn) + + [-1] + * ( + Np + + Nn + - min(len(plabelpoints), Np) + - min(len(nlabelpoints), Nn) + ) + ).to(device) + ) + else: + # pad the background labels + _point.append(torch.zeros(Np + Nn, 3).to(device)) # all 0 + _point_label.append(torch.zeros(Np + Nn).to(device) - 1) # -1 not a point + point = torch.stack(_point) + point_label = torch.stack(_point_label) + return point, point_label + +def erode3d(input_tensor, erosion=3): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to( + input_tensor.device + ) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + ( + erosion[2] // 2, + erosion[2] // 2, + erosion[1] // 2, + erosion[1] // 2, + erosion[0] // 2, + erosion[0] // 2, + ), + mode="constant", + value=1.0, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + def remove_small_objects( img: NdarrayTensor, From 4899b14c760813385974356c4a73c25c420e1635 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 19:03:20 +0000 Subject: [PATCH 02/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/segresnet_ds.py | 8 ++++---- monai/networks/nets/vista3d.py | 7 +++---- monai/transforms/utils.py | 4 ++-- 3 
files changed, 9 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 1041c07259..e143eb57da 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -429,8 +429,8 @@ def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tenso class SegResNetDS2(SegResNetDS): """ - SegResNetDS2 is the image encoder used by VISTA3D. It adds one additional decoder branch. - """ + SegResNetDS2 is the image encoder used by VISTA3D. It adds one additional decoder branch. + """ def __init__( self, spatial_dims: int = 3, @@ -459,7 +459,7 @@ def __init__( preprocess = preprocess, upsample_mode = upsample_mode, resolution = resolution) - + if spatial_dims not in (1, 2, 3): raise ValueError("`spatial_dims` can only be 1, 2 or 3.") @@ -533,7 +533,7 @@ def __init__( } ) self.up_layers_auto.append(level_auto) - + def _forward( self, x: torch.Tensor, with_point, with_label ) -> Union[None, torch.Tensor, list[torch.Tensor]]: diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 9237f0d99d..64ff066733 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -17,7 +17,6 @@ from torch import Tensor, nn import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F import monai @@ -61,8 +60,8 @@ def get_bs(self, class_vector, point_coords): return point_coords.shape[0] else: return class_vector.shape[0] - - def convert_point_label(self, point_label, label_set=None, + + def convert_point_label(self, point_label, label_set=None, special_index=[23, 24, 25, 26, 27, 57, 128]): if label_set is None: return point_label @@ -560,7 +559,7 @@ def forward(self, out, point_coords, point_labels, class_vector=None): ) masks = torch.vstack(masks) return masks - + class Class_Mapping_Classify(nn.Module): def __init__(self, n_classes, feature_size, use_mlp=False): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index f5dd2f7a88..05706984b8 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1231,14 +1231,14 @@ def get_largest_connected_component_mask_point( def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): """ Convert a 3D point coordinates into image mask. The returned mask has the same spatial - size as `image_size` while the batch dimension is the same as point' batch dimension. + size as `image_size` while the batch dimension is the same as point' batch dimension. The point is converted to a mask ball with radius defined by `radius`. Args: image_size: The output size of th point: [b, N, 3] point_label: [b, N] radius: disc ball radius size - disc: If true, use regular disc other other use gaussian. + disc: If true, use regular disc other other use gaussian. 
""" if not torch.is_tensor(point): point = torch.from_numpy(point) From 66408f51fff7ee1990c78b150c12344dc757e85e Mon Sep 17 00:00:00 2001 From: Yufan He Date: Mon, 5 Aug 2024 15:36:13 -0400 Subject: [PATCH 03/30] Fix comments Signed-off-by: Yufan He --- monai/networks/nets/segresnet_ds.py | 18 ++-- monai/networks/nets/vista3d.py | 50 +++------- monai/transforms/utils.py | 139 +++++++--------------------- 3 files changed, 56 insertions(+), 151 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index e143eb57da..be8d90c49c 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -460,15 +460,6 @@ def __init__( upsample_mode = upsample_mode, resolution = resolution) - if spatial_dims not in (1, 2, 3): - raise ValueError("`spatial_dims` can only be 1, 2 or 3.") - - if resolution is not None: - if not isinstance(resolution, (list, tuple)): - raise TypeError("resolution must be a tuple") - elif not all(r > 0 for r in resolution): - raise ValueError("resolution must be positive") - # ensure normalization had affine trainable parameters (if not specified) norm = split_args(norm) if has_option(Norm[norm[0], spatial_dims], "affine"): @@ -517,9 +508,7 @@ def __init__( level_auto["head"] = nn.Identity() self.up_layers_auto.append(level_auto) - if ( - n_up == 0 - ): # in a corner case of flat structure (no downsampling), attache a single head + if n_up == 0: # in a corner case of flat structure (no downsampling), attache a single head level_auto = nn.ModuleDict( { "upsample": nn.Identity(), @@ -598,6 +587,11 @@ def forward( return self._forward(x, with_point, with_label) def set_auto_grad(self, auto_freeze=False, point_freeze=False): + """ + Args: + auto_freeze: if true, freeze the image encoder and the auto-branch. + point_freeze: if true, freeze the image encoder and the point-branch. 
+ """ for param in self.encoder.parameters(): param.requires_grad = (not auto_freeze) and (not point_freeze) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 64ff066733..6a6c7b32b9 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -21,11 +21,10 @@ import monai from monai.networks.blocks import UnetrBasicBlock +from monai.networks.blocks import MLPBlock from monai.transforms.utils import get_largest_connected_component_mask_point as lcc from monai.transforms.utils import convert_points_to_disc, sample_points_from_label - -from scripts.utils.workflow_utils import sample_points_patch_val - +from monai.utils import optional_import, unsqueeze_left, unsqueeze_right rearrange, _ = optional_import("einops", name="rearrange") @@ -126,16 +125,10 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): patch_coords[-1].start, ] # update point coords - patch_starts = ( - torch.tensor(patch_starts, device=point_coords.device) - .unsqueeze(0) - .unsqueeze(0) - ) - patch_ends = ( - torch.tensor(patch_ends, device=point_coords.device) - .unsqueeze(0) - .unsqueeze(0) - ) + patch_starts = unsqueeze_left(torch.tensor(patch_starts, + device=point_coords.device), 5) + patch_ends = unsqueeze_left(torch.tensor(patch_ends, + device=point_coords.device), 5) # [1 N 1] indices = torch.logical_and( ((point_coords - patch_starts) > 0).all(2), @@ -173,15 +166,8 @@ def connected_components_combine( inside.append( np.any( [ - _logits[ - i, - 0, - round(p[0].item()), - round(p[1].item()), - round(p[2].item()), - ].item() - > 0 - for p in point_coords[i] + _logits[i,0,p[0],p[1],p[2]].item() > 0 + for p in point_coords[i].cpu().numpy().round() ] ) ) @@ -190,10 +176,7 @@ def connected_components_combine( _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() pos_region = point_logits.sigmoid() > thred diff_pos = torch.logical_and( - torch.logical_or( - (_logits <= thred), - inside.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), - ), + torch.logical_or(_logits <= thred, unsqueeze_right(inside, 5)), pos_region, ) diff_neg = torch.logical_and((_logits > thred), ~pos_region) @@ -202,7 +185,8 @@ def connected_components_combine( ) # cc is the region that can be updated by point_logits. cc = cc.to(logits.device) - # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, only remove unconnected positive region. + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, + # only remove unconnected positive region. uc_pos_region = torch.logical_and(pos_region, ~cc) fill_mask = torch.logical_and(nan_mask, uc_pos_region) if fill_mask.any(): @@ -212,7 +196,6 @@ def connected_components_combine( cc = torch.logical_or(nan_mask, cc).to(logits.dtype) logits[mapping_index] *= 1 - cc logits[mapping_index] += cc * point_logits - # debug_ccp(_logits, point_logits.sigmoid(), point_coords, point_labels, diff, cc, logits[mapping_index], np.random.randint(10000)) return logits def gaussian_combine( @@ -303,7 +286,7 @@ def forward( image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: - return NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) bs = self.get_bs(class_vector, point_coords) if patch_coords is not None: @@ -311,7 +294,7 @@ def forward( if labels is not None and label_set is not None: # if labels is not None, sample from labels for each patch. 
if val_point_sampler is None: - val_point_sampler = sample_points_patch_val + val_point_sampler = self.sample_points_patch_val point_coords, point_labels, prompt_class = val_point_sampler( labels, patch_coords, label_set ) @@ -552,11 +535,8 @@ def forward(self, out, point_coords, point_labels, class_vector=None): src = src.transpose(1, 2).view(b, c, h, w, d) upscaled_embedding = self.output_upscaling(src) b, c, h, w, d = upscaled_embedding.shape - masks.append( - (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view( - b, -1, h, w, d - ) - ) + mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) masks = torch.vstack(masks) return masks diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 05706984b8..925ccb1113 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -66,6 +66,8 @@ min_version, optional_import, pytorch_after, + unsqueeze_right, + unsqueeze_left ) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import ( @@ -1212,17 +1214,12 @@ def get_largest_connected_component_mask_point( # if -1 padding point, skip continue for margin in range(margins): - left, right = max(p[0].round().int().item() - margin, 0), min( - p[0].round().int().item() + margin + 1, features.shape[-3] - ) - t, d = max(p[1].round().int().item() - margin, 0), min( - p[1].round().int().item() + margin + 1, features.shape[-2] - ) - f, b = max(p[2].round().int().item() - margin, 0), min( - p[2].round().int().item() + margin + 1, features.shape[-1] - ) - if (features[bs, 0, left:right, t:d, f:b] > 0).any(): - index = features[bs, 0, left:right, t:d, f:b].max() + x, y, z = p.round().int().tolist() + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): + index = features[bs, 0, l:r, t:d, f:b].max() outs[[bs]] += lib.isin(features[[bs]], index) break outs[outs > 1] = 1 @@ -1232,11 +1229,12 @@ def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False) """ Convert a 3D point coordinates into image mask. The returned mask has the same spatial size as `image_size` while the batch dimension is the same as point' batch dimension. - The point is converted to a mask ball with radius defined by `radius`. + The point is converted to a mask ball with radius defined by `radius`. The output + contains two channels each for negative (first channel) and positive points. Args: image_size: The output size of th point: [b, N, 3] - point_label: [b, N] + point_label: [b, N], 0 or 2 means negative points, 1 or 3 means postive points. radius: disc ball radius size disc: If true, use regular disc other other use gaussian. 
""" @@ -1246,45 +1244,23 @@ def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False) [point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device, ) - row_array = torch.arange( - start=0, end=image_size[0], step=1, dtype=torch.float32, device=point.device - ) - col_array = torch.arange( - start=0, end=image_size[1], step=1, dtype=torch.float32, device=point.device - ) - z_array = torch.arange( - start=0, end=image_size[2], step=1, dtype=torch.float32, device=point.device - ) - coord_rows, coord_cols, coord_z = torch.meshgrid(z_array, col_array, row_array) - # [1,3,h,w,d] -> [b, 2, 3, h,w,d] - coords = ( - torch.stack((coord_rows, coord_cols, coord_z), dim=0) - .unsqueeze(0) - .unsqueeze(0) - .repeat(point.shape[0], 2, 1, 1, 1, 1) - ) + _array = [torch.arange( + start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device + ) for i in range(3)] + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) + # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] + coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) + coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) for b in range(point.shape[0]): for n in range(point.shape[1]): + point_bn = unsqueeze_right(point[b, n], 6) if point_label[b, n] > -1: channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) if disc: - masks[b, channel] += ( - torch.pow( - coords[b, channel] - - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), - 2, - ).sum(0) - < radius**2 - ) + masks[b, channel] += pow_diff.sum(0) < radius**2 else: - masks[b, channel] += torch.exp( - -torch.pow( - coords[b, channel] - - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), - 2, - ).sum(0) - / (2 * radius**2) - ) + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) return masks def sample_points_from_label( @@ -1307,8 +1283,6 @@ def sample_points_from_label( unique_labels = labels.unique().cpu().numpy().tolist() _point = [] _point_label = [] - Nn = max_npoint - Np = max_ppoint for id in label_set: if id in unique_labels: plabels = labels == int(id) @@ -1318,71 +1292,28 @@ def sample_points_from_label( if len(plabelpoints) == 0: plabelpoints = torch.nonzero(plabels).to(device) nlabelpoints = torch.nonzero(nlabels).to(device) + Np = min(len(plabelpoints), max_ppoint) + Nn = min(len(nlabelpoints), max_npoint) + pad = max_ppoint + max_npoint - Np - Nn if use_center: pmean = plabelpoints.float().mean(0) pdis = ((plabelpoints - pmean) ** 2).sum(-1) _, sorted_indices = torch.sort(pdis) - _point.append( - torch.stack( - [ - plabelpoints[sorted_indices[i]] - for i in range(min(len(plabelpoints), Np)) - ] - + random.choices(nlabelpoints, k=min(len(nlabelpoints), Nn)) - + [torch.tensor([0, 0, 0], device=device)] - * ( - Np - + Nn - - min(len(plabelpoints), Np) - - min(len(nlabelpoints), Nn) - ) - ) - ) - _point_label.append( - torch.tensor( - [1] * min(len(plabelpoints), Np) - + [0.0] * min(len(nlabelpoints), Nn) - + [-1] - * ( - Np - + Nn - - min(len(plabelpoints), Np) - - min(len(nlabelpoints), Nn) - ) - ).to(device) - ) - else: - _point.append( - torch.stack( - random.choices(plabelpoints, k=min(len(plabelpoints), Np)) - + random.choices(nlabelpoints, k=min(len(nlabelpoints), Nn)) - + [torch.tensor([0, 0, 0], device=device)] - * ( - Np - + Nn - - min(len(plabelpoints), Np) - - min(len(nlabelpoints), Nn) - ) + sorted_indices = list(range(len(plabelpoints))) + 
random.shuffle(sorted_indices) + _point.append( + torch.stack([plabelpoints[sorted_indices[i]] for i in Np] + + random.choices(nlabelpoints, k=Nn) + + [torch.tensor([0, 0, 0], device=device)] * pad ) ) - _point_label.append( - torch.tensor( - [1] * min(len(plabelpoints), Np) - + [0.0] * min(len(nlabelpoints), Nn) - + [-1] - * ( - Np - + Nn - - min(len(plabelpoints), Np) - - min(len(nlabelpoints), Nn) - ) - ).to(device) - ) + _point_label.append( + torch.tensor([1] * Np + [0] * Nn + [-1] * pad).to(device)) else: # pad the background labels - _point.append(torch.zeros(Np + Nn, 3).to(device)) # all 0 - _point_label.append(torch.zeros(Np + Nn).to(device) - 1) # -1 not a point + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) point = torch.stack(_point) point_label = torch.stack(_point_label) return point, point_label From 91815bbcb110266e063b40a5f06900647816e40a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:36:40 +0000 Subject: [PATCH 04/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vista3d.py | 6 +++--- monai/transforms/utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 6a6c7b32b9..6c99afd1b0 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -125,9 +125,9 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): patch_coords[-1].start, ] # update point coords - patch_starts = unsqueeze_left(torch.tensor(patch_starts, + patch_starts = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 5) - patch_ends = unsqueeze_left(torch.tensor(patch_ends, + patch_ends = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 5) # [1 N 1] indices = torch.logical_and( @@ -185,7 +185,7 @@ def connected_components_combine( ) # cc is the region that can be updated by point_logits. cc = cc.to(logits.device) - # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, # only remove unconnected positive region. 
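        # That is, cc marks the voxels the point clicks are allowed to edit; uc_pos_region below is the
        # positive point response that is not connected to any click, and fill_mask restricts the NaN
        # replacement so that such unconnected regions do not leak into the combined result.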
uc_pos_region = torch.logical_and(pos_region, ~cc) fill_mask = torch.logical_and(nan_mask, uc_pos_region) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 925ccb1113..a51be1b815 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1313,7 +1313,7 @@ def sample_points_from_label( else: # pad the background labels _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) - _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) point = torch.stack(_point) point_label = torch.stack(_point_label) return point, point_label From 6710b2334c376bbd41cb9d3cef2f07cd1b8acbd9 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Thu, 8 Aug 2024 16:05:00 -0400 Subject: [PATCH 05/30] Update docstring Signed-off-by: heyufan1995 --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/segresnet_ds.py | 2 +- monai/networks/nets/vista3d.py | 31 ++++++++++++++++++++++++++--- monai/transforms/utils.py | 23 ++++++++++++--------- 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c777fe6442..35c517bdae 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -76,7 +76,7 @@ resnet200, ) from .segresnet import SegResNet, SegResNetVAE -from .segresnet_ds import SegResNetDS +from .segresnet_ds import SegResNetDS, SegResNetDS2 from .senet import ( SENet, SEnet, diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index be8d90c49c..2cd32280fd 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -23,7 +23,7 @@ from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import UpsampleMode, has_option -__all__ = ["SegResNetDS"] +__all__ = ["SegResNetDS", "SegResNetDS2"] def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 6c99afd1b0..0b76a2035c 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -22,6 +22,7 @@ import monai from monai.networks.blocks import UnetrBasicBlock from monai.networks.blocks import MLPBlock +from monai.networks.nets import SegResNetDS2 from monai.transforms.utils import get_largest_connected_component_mask_point as lcc from monai.transforms.utils import convert_points_to_disc, sample_points_from_label from monai.utils import optional_import, unsqueeze_left, unsqueeze_right @@ -152,10 +153,18 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): def connected_components_combine( self, logits, point_logits, point_coords, point_labels, mapping_index, thred=0.5 ): - """Combine auto results with point click response, or combine previous mask with point click response. + """ Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks + from a single image patch. Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing. + mapping_index represents the correspondence between B and B1. For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed region in point clicks must be updated by the lcc function. Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added. 
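        For example (illustrative only), with B=4 automatic masks and user clicks added to the first and
        third of them, mapping_index selects those two entries, and point_logits/point_coords/point_labels
        each carry B1=2 leading entries in the same order.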
+ Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. """ logits = ( logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits @@ -173,6 +182,7 @@ def connected_components_combine( ) inside = torch.tensor(inside).to(logits.device) nan_mask = torch.isnan(_logits) + # _logits are converted to binary [B1, 1, H, W, D] _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() pos_region = point_logits.sigmoid() > thred diff_pos = torch.logical_and( @@ -383,7 +393,7 @@ def forward( self.image_embeddings = out.detach() return logits -class Point_Mapping_SAM(nn.Module): +class PointMappingSAM(nn.Module): def __init__( self, feature_size, @@ -541,7 +551,7 @@ def forward(self, out, point_coords, point_labels, class_vector=None): return masks -class Class_Mapping_Classify(nn.Module): +class ClassMappingClassify(nn.Module): def __init__(self, n_classes, feature_size, use_mlp=False): super().__init__() self.use_mlp = use_mlp @@ -590,6 +600,21 @@ def forward(self, src, class_vector): masks = torch.cat(masks, 1) return masks, class_embedding +def VISTA3D132(encoder_embed_dim=48, in_channels=1): + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D( + image_encoder=segresnet, class_head=class_head, point_head=point_head, feature_size=encoder_embed_dim + ) + return vista # Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a51be1b815..05d5a1e996 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1179,20 +1179,25 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] def get_largest_connected_component_mask_point( - img_pos: NdarrayTensor, - img_neg: NdarrayTensor, - pos_val: list=[1, 3], - neg_val: list=[0, 2], + img_pos: bool = NdarrayTensor, + img_neg: bool = NdarrayTensor, + pos_val: list = [1, 3], + neg_val: list = [0, 2], point_coords: None = None, point_labels: None = None, margins: int = 3, ) -> NdarrayTensor: """ - Gets the largest connected component mask of an image that include the point_coords. + Gets the connected component of img_pos and img_neg that include the positive points and + negative points separately. The function is used for combining automatic results with interactive + results in VISTA3D. Args: - img_pos: [1, B, H, W, D] - point_coords [B, N, 3] - point_labels [B, N] + img_pos: bool type array. [B, 1, H, W, D]. B foreground masks from a single 3D image. + img_neg: same format as img_pos but corresponds to negative points. + pos_val: positive point label values. + neg_val: negative point label values. 
+ point_coords: [B, N, 3] + point_labels: [B, N] """ img_pos_, *_ = convert_data_type(img_pos, np.ndarray) @@ -1303,7 +1308,7 @@ def sample_points_from_label( sorted_indices = list(range(len(plabelpoints))) random.shuffle(sorted_indices) _point.append( - torch.stack([plabelpoints[sorted_indices[i]] for i in Np] + torch.stack([plabelpoints[sorted_indices[i]] for i in range(Np)] + random.choices(nlabelpoints, k=Nn) + [torch.tensor([0, 0, 0], device=device)] * pad ) From ca6acb8d656d10d34aeecc3d6f56d6c25830d090 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Aug 2024 20:05:27 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 05d5a1e996..2605f09d9d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1188,13 +1188,13 @@ def get_largest_connected_component_mask_point( margins: int = 3, ) -> NdarrayTensor: """ - Gets the connected component of img_pos and img_neg that include the positive points and + Gets the connected component of img_pos and img_neg that include the positive points and negative points separately. The function is used for combining automatic results with interactive - results in VISTA3D. + results in VISTA3D. Args: img_pos: bool type array. [B, 1, H, W, D]. B foreground masks from a single 3D image. img_neg: same format as img_pos but corresponds to negative points. - pos_val: positive point label values. + pos_val: positive point label values. neg_val: negative point label values. point_coords: [B, N, 3] point_labels: [B, N] From 42606e0615bf649aa800f3b7b50335ba35337b83 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 16:59:31 +0800 Subject: [PATCH 07/30] rewrite segresnetds2 Signed-off-by: Yiheng Wang --- monai/networks/nets/segresnet_ds.py | 114 +++++----------------------- 1 file changed, 21 insertions(+), 93 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 2cd32280fd..65ad5a0bbc 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -11,6 +11,7 @@ from __future__ import annotations +import copy from collections.abc import Callable from typing import Union @@ -429,8 +430,10 @@ def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tenso class SegResNetDS2(SegResNetDS): """ - SegResNetDS2 is the image encoder used by VISTA3D. It adds one additional decoder branch. + SegResNetDS2 based on `SegResNetDS` and adds an additional decorder branch. + It is the image encoder used by VISTA3D. 
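    A minimal usage sketch (argument values are illustrative, not a recommended configuration):

        net = SegResNetDS2(in_channels=1, out_channels=48, init_filters=48, norm="instance", blocks_down=(1, 2, 2, 4, 4))
        x = torch.rand(1, 1, 64, 64, 64)
        point_feat, auto_feat = net(x, with_point=True, with_label=True)
        # point_feat comes from the original decoder (used by the point branch) and auto_feat from the
        # duplicated decoder (used by the automatic branch); with dsdepth=1 each is expected to be a
        # [1, 48, 64, 64, 64] map. Either branch can be skipped via with_point/with_label.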
""" + def __init__( self, spatial_dims: int = 3, @@ -447,92 +450,30 @@ def __init__( resolution: tuple | None = None, ): super().__init__( - spatial_dims = spatial_dims, + spatial_dims=spatial_dims, init_filters=init_filters, in_channels=in_channels, - out_channels= out_channels, - act = act, - norm = norm, - blocks_down = blocks_down, - blocks_up = blocks_up, - dsdepth = dsdepth, - preprocess = preprocess, - upsample_mode = upsample_mode, - resolution = resolution) - - # ensure normalization had affine trainable parameters (if not specified) - norm = split_args(norm) - if has_option(Norm[norm[0], spatial_dims], "affine"): - norm[1].setdefault("affine", True) # type: ignore - - # ensure activation is inplace (if not specified) - act = split_args(act) - if has_option(Act[act[0]], "inplace"): - act[1].setdefault("inplace", True) # type: ignore - - n_up = len(blocks_down) - 1 - - filters = init_filters * 2**n_up - self.up_layers_auto = nn.ModuleList() - - # self.anisotropic_scales and self.blocks_up are created within super().init() - - for i in range(n_up): - filters = filters // 2 - kernel_size, _, stride = ( - aniso_kernel(self.anisotropic_scales[len(self.blocks_up) - i - 1]) - if self.anisotropic_scales - else (3, 1, 2) - ) - - level_auto = nn.ModuleDict() - blocks = [ - SegResBlock( - spatial_dims=spatial_dims, - in_channels=filters, - kernel_size=kernel_size, - norm=norm, - act=act, - ) - for _ in range(self.blocks_up[i]) - ] - level_auto["blocks"] = nn.Sequential(*blocks) - if len(self.blocks_up) - i <= dsdepth: # deep supervision heads - level_auto["head"] = Conv[Conv.CONV, spatial_dims]( - in_channels=filters, - out_channels=out_channels, - kernel_size=1, - bias=True, - ) - else: - level_auto["head"] = nn.Identity() - self.up_layers_auto.append(level_auto) + out_channels=out_channels, + act=act, + norm=norm, + blocks_down=blocks_down, + blocks_up=blocks_up, + dsdepth=dsdepth, + preprocess=preprocess, + upsample_mode=upsample_mode, + resolution=resolution, + ) - if n_up == 0: # in a corner case of flat structure (no downsampling), attache a single head - level_auto = nn.ModuleDict( - { - "upsample": nn.Identity(), - "blocks": nn.Identity(), - "head": Conv[Conv.CONV, spatial_dims]( - in_channels=filters, - out_channels=out_channels, - kernel_size=1, - bias=True, - ), - } - ) - self.up_layers_auto.append(level_auto) + self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) - def _forward( - self, x: torch.Tensor, with_point, with_label - ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + def forward( # type: ignore + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, **kwargs + ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: if self.preprocess is not None: x = self.preprocess(x) if not self.is_valid_shape(x): - raise ValueError( - f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}" - ) + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") x_down = self.encoder(x) @@ -571,20 +512,7 @@ def _forward( outputs_auto.reverse() - # in eval() mode, always return a single final output - if not self.training or len(outputs) == 1: - outputs = outputs[0] if len(outputs) == 1 else outputs - - if not self.training or len(outputs_auto) == 1: - outputs_auto = outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto - - # return a list of DS outputs - return outputs, outputs_auto - - def forward( - self, x: torch.Tensor, 
with_point=True, with_label=True, **kwargs - ) -> Union[None, torch.Tensor, list[torch.Tensor]]: - return self._forward(x, with_point, with_label) + return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto def set_auto_grad(self, auto_freeze=False, point_freeze=False): """ From 5c071bafd18176405c20da9ff76e8b590e10c228 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 18:07:24 +0800 Subject: [PATCH 08/30] update segresnet ds doc Signed-off-by: Yiheng Wang --- monai/networks/nets/segresnet_ds.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 65ad5a0bbc..a7c7e7cd94 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -469,6 +469,12 @@ def __init__( def forward( # type: ignore self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, **kwargs ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: + """ + Args: + x: input tensor. + with_point: if true, return the point branch output. + with_label: if true, return the label branch output. + """ if self.preprocess is not None: x = self.preprocess(x) From a82b44c10c25d917e896a842efabd2894cbef1cc Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 22:25:50 +0800 Subject: [PATCH 09/30] replace mlpblock Signed-off-by: Yiheng Wang --- monai/networks/nets/vista3d.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 0b76a2035c..03ca4dfbd1 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -642,7 +642,7 @@ def __init__( embedding_dim: int, num_heads: int, mlp_dim: int, - activation: Type[nn.Module] = nn.ReLU, + activation: tuple | str = "relu", attention_downsample_rate: int = 2, ) -> None: """ @@ -734,7 +734,7 @@ def __init__( embedding_dim: int, num_heads: int, mlp_dim: int = 2048, - activation: Type[nn.Module] = nn.ReLU, + activation: tuple | str = "relu", attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: @@ -760,7 +760,7 @@ def __init__( ) self.norm2 = nn.LayerNorm(embedding_dim) - self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation, dropout_mode="vista3d") self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) @@ -915,22 +915,6 @@ def forward_with_coords( return self._pe_encoding(coords.to(torch.float)) # B x N x C -class MLPBlock(nn.Module): - def __init__( - self, - embedding_dim: int, - mlp_dim: int, - act: Type[nn.Module] = nn.GELU, - ) -> None: - super().__init__() - self.lin1 = nn.Linear(embedding_dim, mlp_dim) - self.lin2 = nn.Linear(mlp_dim, embedding_dim) - self.act = act() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.lin2(self.act(self.lin1(x))) - - class MLP(nn.Module): def __init__( self, From 6faf80698d1966021e03a0c685f080799b9bda6b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:26:20 +0000 Subject: [PATCH 10/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vista3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 
03ca4dfbd1..9d1fae1223 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -12,7 +12,7 @@ from __future__ import annotations import math -from typing import Any, Optional, Tuple, Type +from typing import Any, Optional, Tuple from torch import Tensor, nn import numpy as np From 2c06d173418458e1ee83d8e13cb9e94722c0a058 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 22:41:47 +0800 Subject: [PATCH 11/30] resolve conflicts Signed-off-by: Yiheng Wang --- monai/transforms/utils.py | 175 -------------------------------------- 1 file changed, 175 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e805591573..e32bf6fc48 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1177,181 +1177,6 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] -def get_largest_connected_component_mask_point( - img_pos: bool = NdarrayTensor, - img_neg: bool = NdarrayTensor, - pos_val: list = [1, 3], - neg_val: list = [0, 2], - point_coords: None = None, - point_labels: None = None, - margins: int = 3, -) -> NdarrayTensor: - """ - Gets the connected component of img_pos and img_neg that include the positive points and - negative points separately. The function is used for combining automatic results with interactive - results in VISTA3D. - Args: - img_pos: bool type array. [B, 1, H, W, D]. B foreground masks from a single 3D image. - img_neg: same format as img_pos but corresponds to negative points. - pos_val: positive point label values. - neg_val: negative point label values. - point_coords: [B, N, 3] - point_labels: [B, N] - """ - - img_pos_, *_ = convert_data_type(img_pos, np.ndarray) - img_neg_, *_ = convert_data_type(img_neg, np.ndarray) - label = measure.label - lib = np - - features_pos, num_features = label(img_pos_, connectivity=3, return_num=True) - features_neg, num_features = label(img_neg_, connectivity=3, return_num=True) - - outs = np.zeros_like(img_pos_) - for bs in range(point_coords.shape[0]): - for i, p in enumerate(point_coords[bs]): - if point_labels[bs, i] in pos_val: - features = features_pos - elif point_labels[bs, i] in neg_val: - features = features_neg - else: - # if -1 padding point, skip - continue - for margin in range(margins): - x, y, z = p.round().int().tolist() - l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) - t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) - f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) - if (features[bs, 0, l:r, t:d, f:b] > 0).any(): - index = features[bs, 0, l:r, t:d, f:b].max() - outs[[bs]] += lib.isin(features[[bs]], index) - break - outs[outs > 1] = 1 - return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] - -def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): - """ - Convert a 3D point coordinates into image mask. The returned mask has the same spatial - size as `image_size` while the batch dimension is the same as point' batch dimension. - The point is converted to a mask ball with radius defined by `radius`. The output - contains two channels each for negative (first channel) and positive points. - Args: - image_size: The output size of th - point: [b, N, 3] - point_label: [b, N], 0 or 2 means negative points, 1 or 3 means postive points. - radius: disc ball radius size - disc: If true, use regular disc other other use gaussian. 
- """ - if not torch.is_tensor(point): - point = torch.from_numpy(point) - masks = torch.zeros( - [point.shape[0], 2, image_size[0], image_size[1], image_size[2]], - device=point.device, - ) - _array = [torch.arange( - start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device - ) for i in range(3)] - coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) - # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] - coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) - coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) - for b in range(point.shape[0]): - for n in range(point.shape[1]): - point_bn = unsqueeze_right(point[b, n], 6) - if point_label[b, n] > -1: - channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 - pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) - if disc: - masks[b, channel] += pow_diff.sum(0) < radius**2 - else: - masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) - return masks - -def sample_points_from_label( - labels, label_set=None, max_ppoint=1, max_npoint=0, device="cpu", use_center=False -): - """Sample points from labels. - Args: - labels: [1, 1, H, W, D] - label_set: local index, must match values in labels. - max_ppoint: maximum positive point samples. - max_npoint: maximum negative point samples. - device: returned tensor device. - use_center: whether to sample points from center. - Returns: - point: point coordinates of [B, N, 3]. - point_label: [B, N], always 0 for negative, 1 for positive. - """ - assert labels.shape[0] == 1, "only support batch size 1" - labels = labels[0, 0] - unique_labels = labels.unique().cpu().numpy().tolist() - _point = [] - _point_label = [] - for id in label_set: - if id in unique_labels: - plabels = labels == int(id) - nlabels = ~plabels - _plabels = get_largest_connected_component_mask(erode3d(plabels)) - plabelpoints = torch.nonzero(_plabels).to(device) - if len(plabelpoints) == 0: - plabelpoints = torch.nonzero(plabels).to(device) - nlabelpoints = torch.nonzero(nlabels).to(device) - Np = min(len(plabelpoints), max_ppoint) - Nn = min(len(nlabelpoints), max_npoint) - pad = max_ppoint + max_npoint - Np - Nn - if use_center: - pmean = plabelpoints.float().mean(0) - pdis = ((plabelpoints - pmean) ** 2).sum(-1) - _, sorted_indices = torch.sort(pdis) - else: - sorted_indices = list(range(len(plabelpoints))) - random.shuffle(sorted_indices) - _point.append( - torch.stack([plabelpoints[sorted_indices[i]] for i in range(Np)] - + random.choices(nlabelpoints, k=Nn) - + [torch.tensor([0, 0, 0], device=device)] * pad - ) - ) - _point_label.append( - torch.tensor([1] * Np + [0] * Nn + [-1] * pad).to(device)) - else: - # pad the background labels - _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) - _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) - point = torch.stack(_point) - point_label = torch.stack(_point_label) - return point, point_label - -def erode3d(input_tensor, erosion=3): - # Define the structuring element - erosion = ensure_tuple_rep(erosion, 3) - structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to( - input_tensor.device - ) - - # Pad the input tensor to handle border pixels - input_padded = F.pad( - input_tensor.float().unsqueeze(0).unsqueeze(0), - ( - erosion[2] // 2, - erosion[2] // 2, - erosion[1] // 2, - erosion[1] // 2, - erosion[0] // 2, - erosion[0] // 2, - ), - mode="constant", - value=1.0, - ) - - # Apply erosion operation - output = 
F.conv3d(input_padded, structuring_element, padding=0) - - # Set output values based on the minimum value within the structuring element - output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) - - return output.squeeze(0).squeeze(0) - def get_largest_connected_component_mask_point( img_pos: NdarrayTensor, From 6da1a843d946881b6d6133a1b925176fd787474f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 22:55:09 +0800 Subject: [PATCH 12/30] fix arg naming error Signed-off-by: Yiheng Wang --- monai/networks/nets/vista3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 9d1fae1223..f934014b77 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -760,7 +760,7 @@ def __init__( ) self.norm2 = nn.LayerNorm(embedding_dim) - self.mlp = MLPBlock(embedding_dim, mlp_dim, activation, dropout_mode="vista3d") + self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d") self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) From f48a475d09341d19a6e9f1d5d64de21db8c899b1 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 9 Aug 2024 23:05:16 +0800 Subject: [PATCH 13/30] add to init Signed-off-by: Yiheng Wang --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/vista3d.py | 8 ++------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 35c517bdae..748d73bd62 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -120,6 +120,7 @@ from .varautoencoder import VarAutoEncoder from .vit import ViT from .vitautoenc import ViTAutoEnc +from .vista3d import VISTA3D, VISTA3D132 from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet from .vqvae import VQVAE diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index f934014b77..a45f0f9d89 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -29,9 +29,7 @@ rearrange, _ = optional_import("einops", name="rearrange") -__all__ = ["VISTA3D"] - - +__all__ = ["VISTA3D", "VISTA3D132"] class VISTA3D(nn.Module): @@ -611,9 +609,7 @@ def VISTA3D132(encoder_embed_dim=48, in_channels=1): ) point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) - vista = VISTA3D( - image_encoder=segresnet, class_head=class_head, point_head=point_head, feature_size=encoder_embed_dim - ) + vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) return vista # Copyright (c) MONAI Consortium From ff9855eab4e2d7ad4c5bbc79f0033a357ee526a4 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Fri, 9 Aug 2024 18:34:25 -0400 Subject: [PATCH 14/30] Update docstring and tested using gui Signed-off-by: heyufan1995 --- monai/networks/nets/vista3d.py | 41 +++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index a45f0f9d89..25ea692b84 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -53,6 +53,8 @@ def __init__(self, image_encoder, class_head, point_head): self.PINF_VALUE = 9999 def get_bs(self, class_vector, point_coords): + """ Get number of foreground classes based on class and point prompt. 
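        For example, a class_vector holding 5 class prompts (shape [5, 1]), or point_coords of shape
        [5, N, 3] when class_vector is None, gives 5.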
+ """ if class_vector is None: assert point_coords is not None, "prompt is required" return point_coords.shape[0] @@ -61,6 +63,10 @@ def get_bs(self, class_vector, point_coords): def convert_point_label(self, point_label, label_set=None, special_index=[23, 24, 25, 26, 27, 57, 128]): + """ Convert point label based on its class prompt. For special classes defined in special index, + the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those + classes with ambiguous classes. + """ if label_set is None: return point_label assert point_label.shape[0] == len(label_set) @@ -112,6 +118,11 @@ def sample_points_patch_val( def update_point_to_patch(self, patch_coords, point_coords, point_labels): """ Update point_coords with respect to patch coords. If point is outside of the patch, remove the coordinates and set label to -1 + Args: + patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is + passed from sliding_window_inferer. + point_coords: point coordinates, [B, N, 3]. + point_labels: point labels, [B, N]. """ patch_ends = [ patch_coords[-3].stop, @@ -125,9 +136,9 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): ] # update point coords patch_starts = unsqueeze_left(torch.tensor(patch_starts, - device=point_coords.device), 5) + device=point_coords.device), 2) patch_ends = unsqueeze_left(torch.tensor(patch_ends, - device=point_coords.device), 5) + device=point_coords.device), 2) # [1 N 1] indices = torch.logical_and( ((point_coords - patch_starts) > 0).all(2), @@ -174,7 +185,7 @@ def connected_components_combine( np.any( [ _logits[i,0,p[0],p[1],p[2]].item() > 0 - for p in point_coords[i].cpu().numpy().round() + for p in point_coords[i].cpu().numpy().round().astype(int) ] ) ) @@ -209,7 +220,15 @@ def connected_components_combine( def gaussian_combine( self, logits, point_logits, point_coords, point_labels, mapping_index, radius ): - """Combine point results with auto results using gaussian.""" + """ Combine point results with auto results using gaussian. + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + radius: gaussian ball radius. + """ if radius is None: radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 weight = 1 - convert_points_to_disc( @@ -224,7 +243,11 @@ def gaussian_combine( return logits def set_auto_grad(self, auto_freeze=False, point_freeze=False): - """Freeze auto-branch or point-branch""" + """Freeze auto-branch or point-branch. + Args: + auto_freeze: freeze the auto branch. + point_freeze: freeze the point branch. + """ if auto_freeze != self.auto_freeze: if hasattr(self.image_encoder, "set_auto_grad"): self.image_encoder.set_auto_grad( @@ -279,7 +302,7 @@ def forward( provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] will be considered novel class. patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is - passed from monai_utils.sliding_window_inferer. This is an indicator for training phase or validation phase. + passed from sliding_window_inferer. This is an indicator for training phase or validation phase. labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation label_set: the label index matching the indexes in labels. 
If labels are mapped to global index using RelabelID, this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot @@ -591,10 +614,8 @@ def forward(self, src, class_vector): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. masks = [] for i in range(b): - mask = (class_embedding @ src[[i]].view(1, c, h * w * d)).view( - -1, 1, h, w, d - ) - masks.append(mask) + mask = (class_embedding @ src[[i]].view(1, c, h * w * d)) + masks.append(mask.view(-1, 1, h, w, d)) masks = torch.cat(masks, 1) return masks, class_embedding From c51446faf6f8782068d51a9490dcee2462f66b35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 22:35:24 +0000 Subject: [PATCH 15/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vista3d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 25ea692b84..5ad63532e9 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -65,7 +65,7 @@ def convert_point_label(self, point_label, label_set=None, special_index=[23, 24, 25, 26, 27, 57, 128]): """ Convert point label based on its class prompt. For special classes defined in special index, the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those - classes with ambiguous classes. + classes with ambiguous classes. """ if label_set is None: return point_label @@ -120,9 +120,9 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): If point is outside of the patch, remove the coordinates and set label to -1 Args: patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is - passed from sliding_window_inferer. + passed from sliding_window_inferer. point_coords: point coordinates, [B, N, 3]. - point_labels: point labels, [B, N]. + point_labels: point labels, [B, N]. """ patch_ends = [ patch_coords[-3].stop, From 696f38307e53edca87ef2f2fe278e005adad9089 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 12 Aug 2024 08:09:36 -0400 Subject: [PATCH 16/30] Minor docstring update Signed-off-by: heyufan1995 --- monai/networks/nets/vista3d.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 5ad63532e9..88dbe32a9e 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -97,7 +97,8 @@ def sample_points_patch_val( patch_coords: sliding window slice object label_set: local index, must match values in labels use_center: sample points from the center - mapped_label_set: global index, it is used to identify special classes. + mapped_label_set: global index, it is used to identify special classes and is the global index + for the sampled points. max_ppoint/max_npoint: positive points and negative points to sample. """ point_coords, point_labels = sample_points_from_label( @@ -620,6 +621,11 @@ def forward(self, src, class_vector): return masks, class_embedding def VISTA3D132(encoder_embed_dim=48, in_channels=1): + """ Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. The model treats class index larger than 132 as zero-shot. + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. 
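    A rough usage sketch (shapes and class indices are illustrative only):

        model = VISTA3D132(encoder_embed_dim=48, in_channels=1)
        model.eval()
        image = torch.rand(1, 1, 64, 64, 64)
        # fully automatic segmentation of three global class indices
        auto_logits = model(image, class_vector=torch.tensor([1, 2, 3]))      # expected [3, 1, 64, 64, 64]
        # interactive segmentation from a single positive click in patch coordinates
        point = torch.tensor([[[32.0, 32.0, 32.0]]])                          # [B=1, N=1, 3]
        label = torch.tensor([[1]])                                           # [B=1, N=1]
        point_logits = model(image, point_coords=point, point_labels=label)   # expected [1, 1, 64, 64, 64]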
+ """ segresnet = SegResNetDS2( in_channels=in_channels, blocks_down=(1, 2, 2, 4, 4), From aebe273808aeeb6440305506914cc0c14c884f50 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:11:10 +0000 Subject: [PATCH 17/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vista3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 88dbe32a9e..9ea18329d7 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -97,7 +97,7 @@ def sample_points_patch_val( patch_coords: sliding window slice object label_set: local index, must match values in labels use_center: sample points from the center - mapped_label_set: global index, it is used to identify special classes and is the global index + mapped_label_set: global index, it is used to identify special classes and is the global index for the sampled points. max_ppoint/max_npoint: positive points and negative points to sample. """ From 2430f64aa105992c0d582466dfe59c9993257ea4 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 12 Aug 2024 21:25:45 +0800 Subject: [PATCH 18/30] fix code format issues Signed-off-by: Yiheng Wang --- docs/source/networks.rst | 5 + monai/networks/nets/__init__.py | 2 +- monai/networks/nets/vista3d.py | 615 +++++++++++++++----------------- tests/test_segresnet_ds.py | 86 +++-- 4 files changed, 340 insertions(+), 368 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 249375dfc1..38c59f73a5 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -481,6 +481,11 @@ Nets .. autoclass:: SegResNetDS :members: +`SegResNetDS2` +~~~~~~~~~~~~~~ +.. autoclass:: SegResNetDS2 + :members: + `SegResNetVAE` ~~~~~~~~~~~~~~ .. 
autoclass:: SegResNetVAE diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 748d73bd62..0570c9fcc1 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -118,9 +118,9 @@ from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder +from .vista3d import VISTA3D, vista3d132 from .vit import ViT from .vitautoenc import ViTAutoEnc -from .vista3d import VISTA3D, VISTA3D132 from .vnet import VNet from .voxelmorph import VoxelMorph, VoxelMorphUNet from .vqvae import VQVAE diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 5ad63532e9..662f6c369b 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -12,36 +12,38 @@ from __future__ import annotations import math -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple -from torch import Tensor, nn import numpy as np import torch import torch.nn.functional as F +from torch import nn import monai -from monai.networks.blocks import UnetrBasicBlock -from monai.networks.blocks import MLPBlock +from monai.networks.blocks import MLPBlock, UnetrBasicBlock from monai.networks.nets import SegResNetDS2 +from monai.transforms.utils import convert_points_to_disc from monai.transforms.utils import get_largest_connected_component_mask_point as lcc -from monai.transforms.utils import convert_points_to_disc, sample_points_from_label +from monai.transforms.utils import sample_points_from_label from monai.utils import optional_import, unsqueeze_left, unsqueeze_right rearrange, _ = optional_import("einops", name="rearrange") -__all__ = ["VISTA3D", "VISTA3D132"] +__all__ = ["VISTA3D", "vista3d132"] class VISTA3D(nn.Module): """ VISTA3D based on `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography https://arxiv.org/abs/2406.05285>`_. + Args: image_encoder: image encoder backbone for feature extraction. class_head: class head used for class index based segmentation point_head: point head used for interactive segmetnation """ - def __init__(self, image_encoder, class_head, point_head): + + def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module): super().__init__() self.image_encoder = image_encoder self.class_head = class_head @@ -52,52 +54,64 @@ def __init__(self, image_encoder, class_head, point_head): self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 - def get_bs(self, class_vector, point_coords): - """ Get number of foreground classes based on class and point prompt. - """ + def get_bs(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: + """Get number of foreground classes based on class and point prompt.""" if class_vector is None: - assert point_coords is not None, "prompt is required" + if point_coords is None: + raise ValueError("class_vector and point_coords cannot be both None.") return point_coords.shape[0] else: return class_vector.shape[0] - def convert_point_label(self, point_label, label_set=None, - special_index=[23, 24, 25, 26, 27, 57, 128]): - """ Convert point label based on its class prompt. For special classes defined in special index, + def convert_point_label( + self, + point_label: torch.Tensor, + label_set: Sequence[int] | None = None, + special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128), + ): + """ + Convert point label based on its class prompt. 
For special classes defined in special index, the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those classes with ambiguous classes. + + Args: + point_label: the point label tensor, [B, N]. + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + special_index: the special class index that needs to be converted. """ if label_set is None: return point_label - assert point_label.shape[0] == len(label_set) + if not point_label.shape[0] == len(label_set): + raise ValueError("point_label and label_set must have the same length.") + for i in range(len(label_set)): if label_set[i] in special_index: for j in range(len(point_label[i])): - point_label[i, j] = ( - point_label[i, j] + 2 - if point_label[i, j] > -1 - else point_label[i, j] - ) + point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j] return point_label def sample_points_patch_val( self, - labels, - patch_coords, - label_set, - use_center=True, - mapped_label_set=None, - max_ppoint=1, - max_npoint=0, - **kwargs + labels: torch.Tensor, + patch_coords: Sequence[slice], + label_set: Sequence[int], + use_center: bool = True, + mapped_label_set: Sequence[int] | None = None, + max_ppoint: int = 1, + max_npoint: int = 0, ): - """Sample points for patch during sliding window validation. Only used for point only validation. + """ + Sample points for patch during sliding window validation. Only used for point only validation. + Args: - labels: [1, 1, H, W, D] - patch_coords: sliding window slice object - label_set: local index, must match values in labels - use_center: sample points from the center - mapped_label_set: global index, it is used to identify special classes. + labels: shape [1, 1, H, W, D]. + patch_coords: a sequence of sliding window slice objects. + label_set: local index, must match values in labels. + use_center: sample points from the center. + mapped_label_set: global index, it is used to identify special classes and is the global index + for the sampled points. max_ppoint/max_npoint: positive points and negative points to sample. """ point_coords, point_labels = sample_points_from_label( @@ -109,43 +123,32 @@ def sample_points_patch_val( use_center=use_center, ) point_labels = self.convert_point_label(point_labels, mapped_label_set) - return ( - point_coords, - point_labels, - torch.tensor(mapped_label_set).to(point_coords.device).unsqueeze(-1), - ) + return (point_coords, point_labels, torch.Tensor(label_set).to(point_coords.device).unsqueeze(-1)) + + def update_point_to_patch( + self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor + ): + """ + Update point_coords with respect to patch coords. + If point is outside of the patch, remove the coordinates and set label to -1. - def update_point_to_patch(self, patch_coords, point_coords, point_labels): - """ Update point_coords with respect to patch coords. - If point is outside of the patch, remove the coordinates and set label to -1 Args: - patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is - passed from sliding_window_inferer. + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. 
+ This value is passed from sliding_window_inferer. point_coords: point coordinates, [B, N, 3]. point_labels: point labels, [B, N]. """ - patch_ends = [ - patch_coords[-3].stop, - patch_coords[-2].stop, - patch_coords[-1].stop, - ] - patch_starts = [ - patch_coords[-3].start, - patch_coords[-2].start, - patch_coords[-1].start, - ] + patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] + patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] # update point coords - patch_starts = unsqueeze_left(torch.tensor(patch_starts, - device=point_coords.device), 2) - patch_ends = unsqueeze_left(torch.tensor(patch_ends, - device=point_coords.device), 2) + patch_starts_tensor = unsqueeze_left(torch.Tensor(patch_starts, device=point_coords.device), 2) + patch_ends_tensor = unsqueeze_left(torch.Tensor(patch_ends, device=point_coords.device), 2) # [1 N 1] indices = torch.logical_and( - ((point_coords - patch_starts) > 0).all(2), - ((patch_ends - point_coords) > 0).all(2), + ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2) ) # check if it's within patch coords - point_coords = point_coords.clone() - patch_starts + point_coords = point_coords.clone() - patch_starts_tensor point_labels = point_labels.clone() if indices.any(): point_labels[~indices] = -1 @@ -154,54 +157,55 @@ def update_point_to_patch(self, patch_coords, point_coords, point_labels): not_pad_indices = (point_labels != -1).any(0) point_coords = point_coords[:, not_pad_indices] point_labels = point_labels[:, not_pad_indices] - else: - point_coords = None - point_labels = None - return point_coords, point_labels + return point_coords, point_labels + return None, None def connected_components_combine( - self, logits, point_logits, point_coords, point_labels, mapping_index, thred=0.5 + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + thred: float = 0.5, ): - """ Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks - from a single image patch. Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing. + """ + Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks + from a single image patch. + Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing. mapping_index represents the correspondence between B and B1. For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed - region in point clicks must be updated by the lcc function. Notice, if a positive point is within logits/prev_mask, the components containing the positive point - will be added. + region in point clicks must be updated by the lcc function. + Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added. + Args: logits: automatic branch results, [B, 1, H, W, D]. point_logits: point branch results, [B1, 1, H, W, D]. point_coords: point coordinates, [B1, N, 3]. point_labels: point labels, [B1, N]. mapping_index: [B]. + thred: the threshold to convert logits to binary. 
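        Conceptually (a simplified sketch that ignores the NaN handling):

            added    = point response > thred where the auto result is <= thred (or the click lies inside it)
            removed  = auto result > thred where the point response is <= thred
            editable = lcc(added, removed, point_coords, point_labels)
            combined = where(editable, point_logits, logits[mapping_index])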
""" - logits = ( - logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits - ) + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits _logits = logits[mapping_index] inside = [] for i in range(_logits.shape[0]): inside.append( np.any( [ - _logits[i,0,p[0],p[1],p[2]].item() > 0 + _logits[i, 0, p[0], p[1], p[2]].item() > 0 for p in point_coords[i].cpu().numpy().round().astype(int) ] ) ) - inside = torch.tensor(inside).to(logits.device) + inside_tensor = torch.Tensor(inside).to(logits.device) nan_mask = torch.isnan(_logits) # _logits are converted to binary [B1, 1, H, W, D] _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() pos_region = point_logits.sigmoid() > thred - diff_pos = torch.logical_and( - torch.logical_or(_logits <= thred, unsqueeze_right(inside, 5)), - pos_region, - ) + diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region) diff_neg = torch.logical_and((_logits > thred), ~pos_region) - cc = lcc( - diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels - ) + cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels) # cc is the region that can be updated by point_logits. cc = cc.to(logits.device) # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, @@ -218,41 +222,47 @@ def connected_components_combine( return logits def gaussian_combine( - self, logits, point_logits, point_coords, point_labels, mapping_index, radius + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + radius: int | None = None, ): - """ Combine point results with auto results using gaussian. - Args: - logits: automatic branch results, [B, 1, H, W, D]. - point_logits: point branch results, [B1, 1, H, W, D]. - point_coords: point coordinates, [B1, N, 3]. - point_labels: point labels, [B1, N]. - mapping_index: [B]. - radius: gaussian ball radius. + """ + Combine point results with auto results using gaussian. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + radius: gaussian ball radius. """ if radius is None: radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 - weight = 1 - convert_points_to_disc( - point_logits.shape[-3:], point_coords, point_labels, radius=radius - ).sum(1, keepdims=True) - weight[weight < 0] = 0 - logits = ( - logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum( + 1, keepdims=True ) + weight[weight < 0] = 0 + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits logits[mapping_index] *= weight logits[mapping_index] += (1 - weight) * point_logits return logits - def set_auto_grad(self, auto_freeze=False, point_freeze=False): - """Freeze auto-branch or point-branch. + def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): + """ + Freeze auto-branch or point-branch. + Args: - auto_freeze: freeze the auto branch. - point_freeze: freeze the point branch. + auto_freeze: whether to freeze the auto branch. + point_freeze: whether to freeze the point branch. 
""" if auto_freeze != self.auto_freeze: if hasattr(self.image_encoder, "set_auto_grad"): - self.image_encoder.set_auto_grad( - auto_freeze=auto_freeze, point_freeze=point_freeze - ) + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) else: for param in self.image_encoder.parameters(): param.requires_grad = (not auto_freeze) and (not point_freeze) @@ -262,9 +272,7 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False): if point_freeze != self.point_freeze: if hasattr(self.image_encoder, "set_auto_grad"): - self.image_encoder.set_auto_grad( - auto_freeze=auto_freeze, point_freeze=point_freeze - ) + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) else: for param in self.image_encoder.parameters(): param.requires_grad = (not auto_freeze) and (not point_freeze) @@ -274,23 +282,24 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False): def forward( self, - input_images, - point_coords=None, - point_labels=None, - class_vector=None, - prompt_class=None, - patch_coords=None, - labels=None, - label_set=None, - prev_mask=None, - radius=None, - val_point_sampler=None, + input_images: torch.Tensor, + point_coords: torch.Tensor | None = None, + point_labels: torch.Tensor | None = None, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + patch_coords: Sequence[slice] | None = None, + labels: torch.Tensor | None = None, + label_set: Sequence[int] | None = None, + prev_mask: torch.Tensor | None = None, + radius: int | None = None, + val_point_sampler: Callable | None = None, **kwargs, ): """ The forward function for VISTA3D. We only support single patch in training and inference. One exception is allowing sliding window batch size > 1 for automatic segmentation only case. B represents number of objects, N represents number of points for each objects. + Args: input_images: [1, 1, H, W, D] point_coords: [B, N, 3] @@ -301,16 +310,17 @@ def forward( the points are for zero-shot or supported class. When class_vector and point_coords are both provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] will be considered novel class. - patch_coords: the python slice object representing the patch coordinates during sliding window inference. This value is - passed from sliding_window_inferer. This is an indicator for training phase or validation phase. + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot evaluation, this label_set should be the original index. - prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. This is the transposed raw output from sliding_window_inferer before - any postprocessing. When user click points to perform auto-results correction, this can be the auto-results. - radius: single float value controling the gaussian blur when combining point and auto results. The gaussian combine is not used - in VISTA3D training but might be useful for finetuning purposes. + prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. 
+ This is the transposed raw output from sliding_window_inferer before any postprocessing. + When user click points to perform auto-results correction, this can be the auto-results. + radius: single float value controling the gaussian blur when combining point and auto results. + The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. """ @@ -319,25 +329,25 @@ def forward( if point_coords is None and class_vector is None: return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + if point_coords is not None and point_labels is None: + raise ValueError("point_labels must be provided when point_coords is provided.") + bs = self.get_bs(class_vector, point_coords) if patch_coords is not None: # if during validation and perform enable based point-validation. if labels is not None and label_set is not None: # if labels is not None, sample from labels for each patch. if val_point_sampler is None: + # TODO: think about how to refactor this part. val_point_sampler = self.sample_points_patch_val - point_coords, point_labels, prompt_class = val_point_sampler( - labels, patch_coords, label_set - ) - if prompt_class[0].item() == 0: - point_labels[0] = -1 + point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set) + if prompt_class[0].item() == 0: # type: ignore + point_labels[0] = -1 # type: ignore labels, prev_mask = None, None elif point_coords is not None: # If not performing patch-based point only validation, use user provided click points for inference. # the point clicks is in original image space, convert it to current patch-coordinate space. - point_coords, point_labels = self.update_point_to_patch( - patch_coords, point_coords, point_labels - ) + point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore if point_coords is not None and point_labels is not None: # remove points that used for padding purposes (point_label = -1) @@ -358,55 +368,39 @@ def forward( if point_coords is None and class_vector is None: return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) - if ( - self.image_embeddings is not None - and kwargs.get("keep_cache", False) - and class_vector is None - ): + if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: out, out_auto = self.image_embeddings, None else: out, out_auto = self.image_encoder( - input_images, - with_point=point_coords is not None, - with_label=class_vector is not None, + input_images, with_point=point_coords is not None, with_label=class_vector is not None ) - input_images = None + # release memory + input_images = None # type: ignore # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: logits, _ = self.class_head(out_auto, class_vector) if point_coords is not None: - point_logits = self.point_head( - out, point_coords, point_labels, class_vector=prompt_class - ) + point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) if patch_coords is None: logits = self.gaussian_combine( - logits, - point_logits, - point_coords, - point_labels, - mapping_index, - radius, + logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore ) else: # during validation use largest component logits = self.connected_components_combine( - logits, point_logits, 
point_coords, point_labels, mapping_index + logits, point_logits, point_coords, point_labels, mapping_index # type: ignore ) else: - logits = self.NINF_VALUE + torch.zeros( - [bs, 1, *image_size], device=device, dtype=out.dtype - ) - logits[mapping_index] = self.point_head( - out, point_coords, point_labels, class_vector=prompt_class - ) + logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype) + logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) if prev_mask is not None and patch_coords is not None: logits = self.connected_components_combine( prev_mask[patch_coords].transpose(1, 0).to(logits.device), logits[mapping_index], - point_coords, - point_labels, + point_coords, # type: ignore + point_labels, # type: ignore mapping_index, ) @@ -414,83 +408,49 @@ def forward( self.image_embeddings = out.detach() return logits + class PointMappingSAM(nn.Module): def __init__( self, - feature_size, - max_prompt=32, - num_add_mask_tokens=2, - n_classes=512, - last_supported=132, + feature_size: int, + max_prompt: int = 32, + num_add_mask_tokens: int = 2, + n_classes: int = 512, + last_supported: int = 132, ): super().__init__() transformer_dim = feature_size self.max_prompt = max_prompt self.feat_downsample = nn.Sequential( - nn.Conv3d( - in_channels=feature_size, - out_channels=feature_size, - kernel_size=3, - stride=2, - padding=1, - ), + nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1), nn.InstanceNorm3d(feature_size), nn.GELU(), - nn.Conv3d( - in_channels=feature_size, - out_channels=transformer_dim, - kernel_size=3, - stride=1, - padding=1, - ), + nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm3d(feature_size), ) - self.mask_downsample = nn.Conv3d( - in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1 - ) + self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1) - self.transformer = TwoWayTransformer( - depth=2, - embedding_dim=transformer_dim, - mlp_dim=512, - num_heads=4, - ) + self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4) self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) - self.point_embeddings = nn.ModuleList( - [nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)] - ) + self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)]) self.not_a_point_embed = nn.Embedding(1, transformer_dim) self.special_class_embed = nn.Embedding(1, transformer_dim) self.mask_tokens = nn.Embedding(1, transformer_dim) self.output_upscaling = nn.Sequential( - nn.ConvTranspose3d( - transformer_dim, - transformer_dim, - kernel_size=3, - stride=2, - padding=1, - output_padding=1, - ), + nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1), nn.InstanceNorm3d(transformer_dim), nn.GELU(), - nn.Conv3d( - transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1 - ), + nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), ) - self.output_hypernetworks_mlps = MLP( - transformer_dim, transformer_dim, transformer_dim, 3 - ) + self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) - ## MultiMask output + # MultiMask output self.num_add_mask_tokens = num_add_mask_tokens self.output_add_hypernetworks_mlps = 
nn.ModuleList( - [ - MLP(transformer_dim, transformer_dim, transformer_dim, 3) - for i in range(self.num_add_mask_tokens) - ] + [MLP(transformer_dim, transformer_dim, transformer_dim, 3) for i in range(self.num_add_mask_tokens)] ) # class embedding self.n_classes = n_classes @@ -499,38 +459,37 @@ def __init__( self.zeroshot_embed = nn.Embedding(1, transformer_dim) self.supported_embed = nn.Embedding(1, transformer_dim) - def forward(self, out, point_coords, point_labels, class_vector=None): + def forward( + self, + out: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + ): # downsample out out_low = self.feat_downsample(out) - out_shape = out.shape[-3:] - out = None + out_shape = tuple(out.shape[-3:]) + # release memory + out = None # type: ignore torch.cuda.empty_cache() # embed points points = point_coords + 0.5 # Shift to center of pixel - point_embedding = self.pe_layer.forward_with_coords(points, out_shape) + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore point_embedding[point_labels == -1] = 0.0 point_embedding[point_labels == -1] += self.not_a_point_embed.weight point_embedding[point_labels == 0] += self.point_embeddings[0].weight point_embedding[point_labels == 1] += self.point_embeddings[1].weight - point_embedding[point_labels == 2] += ( - self.point_embeddings[0].weight + self.special_class_embed.weight - ) - point_embedding[point_labels == 3] += ( - self.point_embeddings[1].weight + self.special_class_embed.weight - ) + point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight output_tokens = self.mask_tokens.weight - output_tokens = output_tokens.unsqueeze(0).expand( - point_embedding.size(0), -1, -1 - ) + output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) if class_vector is None: tokens_all = torch.cat( ( output_tokens, point_embedding, - self.supported_embed.weight.unsqueeze(0).expand( - point_embedding.size(0), -1, -1 - ), + self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1), ), dim=1, ) @@ -542,10 +501,7 @@ def forward(self, out, point_coords, point_labels, class_vector=None): class_embeddings.append(self.zeroshot_embed.weight) else: class_embeddings.append(self.supported_embed.weight) - class_embeddings = torch.stack(class_embeddings) - tokens_all = torch.cat( - (output_tokens, point_embedding, class_embeddings), dim=1 - ) + tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1) # cross attention masks = [] max_prompt = self.max_prompt @@ -556,24 +512,22 @@ def forward(self, out, point_coords, point_labels, class_vector=None): idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) tokens = tokens_all[idx[0] : idx[1]] src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) - pos_src = torch.repeat_interleave( - self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0 - ) + pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0) b, c, h, w, d = src.shape hs, src = self.transformer(src, pos_src, tokens) mask_tokens_out = hs[:, :1, :] hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) - src = src.transpose(1, 2).view(b, c, h, w, d) + src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore upscaled_embedding = 
self.output_upscaling(src) b, c, h, w, d = upscaled_embedding.shape mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d) masks.append(mask.view(-1, 1, h, w, d)) - masks = torch.vstack(masks) - return masks + + return torch.vstack(masks) class ClassMappingClassify(nn.Module): - def __init__(self, n_classes, feature_size, use_mlp=False): + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = False): super().__init__() self.use_mlp = use_mlp if use_mlp: @@ -605,7 +559,7 @@ def __init__(self, n_classes, feature_size, use_mlp=False): ), ) - def forward(self, src, class_vector): + def forward(self, src: torch.Tensor, class_vector: torch.Tensor): b, c, h, w, d = src.shape src = self.image_post_mapping(src) class_embedding = self.class_embeddings(class_vector) @@ -614,12 +568,21 @@ def forward(self, src, class_vector): # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. masks = [] for i in range(b): - mask = (class_embedding @ src[[i]].view(1, c, h * w * d)) + mask = class_embedding @ src[[i]].view(1, c, h * w * d) masks.append(mask.view(-1, 1, h, w, d)) - masks = torch.cat(masks, 1) - return masks, class_embedding -def VISTA3D132(encoder_embed_dim=48, in_channels=1): + return torch.cat(masks, 1), class_embedding + + +def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): + """ + Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. + The model treats class index larger than 132 as zero-shot. + + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. + """ segresnet = SegResNetDS2( in_channels=in_channels, blocks_down=(1, 2, 2, 4, 4), @@ -633,24 +596,6 @@ def VISTA3D132(encoder_embed_dim=48, in_channels=1): vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) return vista -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Adapted from https://github.com/facebookresearch/segment-anything -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - class TwoWayTransformer(nn.Module): def __init__( @@ -665,14 +610,15 @@ def __init__( """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. Args: - depth (int): number of layers in the transformer - embedding_dim (int): the channel dimension for the input embeddings - num_heads (int): the number of heads for multihead attention. Must - divide embedding_dim - mlp_dim (int): the channel dimension internal to the MLP block - activation (nn.Module): the activation to use in the MLP block + depth: number of layers in the transformer. + embedding_dim: the channel dimension for the input embeddings. + num_heads: the number of heads for multihead attention. 
Must divide embedding_dim. + mlp_dim: the channel dimension internal to the MLP block. + activation: the activation to use in the MLP block. + attention_downsample_rate: the rate at which to downsample the image before projecting. """ super().__init__() self.depth = depth @@ -693,32 +639,26 @@ def __init__( ) ) - self.final_attn_token_to_image = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( - self, - image_embedding: Tensor, - image_pe: Tensor, - point_embedding: Tensor, - ) -> Tuple[Tensor, Tensor]: + self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: - image_embedding (torch.Tensor): image to attend to. Should be shape - B x embedding_dim x h x w for any h and w. - image_pe (torch.Tensor): the positional encoding to add to the image. Must - have the same shape as image_embedding. - point_embedding (torch.Tensor): the embedding to add to the query points. - Must have shape B x N_points x embedding_dim for any N_points. + image_embedding: image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe: the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding: the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. Returns: - torch.Tensor: the processed point_embedding - torch.Tensor: the processed image_embedding + torch.torch.Tensor: the processed point_embedding. + torch.torch.Tensor: the processed image_embedding. """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C - bs, c, h, w, d = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) @@ -728,12 +668,7 @@ def forward( # Apply transformer blocks and final layernorm for layer in self.layers: - queries, keys = layer( - queries=queries, - keys=keys, - query_pe=point_embedding, - key_pe=image_pe, - ) + queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe) # Apply the final attention layer from the points to the image q = queries + point_embedding @@ -760,36 +695,33 @@ def __init__( inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. - Arguments: - embedding_dim (int): the channel dimension of the embeddings - num_heads (int): the number of heads in the attention layers - mlp_dim (int): the hidden dimension of the mlp block - activation (nn.Module): the activation of the mlp block - skip_first_layer_pe (bool): skip the PE on the first layer + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + mlp_dim: the hidden dimension of the mlp block. + activation: the activation of the mlp block. + skip_first_layer_pe: skip the PE on the first layer. 
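The decoder above mirrors SAM's two-way design: prompt tokens attend to image tokens and the image tokens attend back. A minimal sketch of exercising it on dummy tensors, assuming `TwoWayTransformer` and `PositionEmbeddingRandom` stay importable from `monai.networks.nets.vista3d` once this patch is merged; all shapes are illustrative.

    # minimal sketch (not part of the patch): running the two-way decoder on dummy tensors
    import torch
    from monai.networks.nets.vista3d import PositionEmbeddingRandom, TwoWayTransformer

    embed_dim = 48
    transformer = TwoWayTransformer(depth=2, embedding_dim=embed_dim, num_heads=4, mlp_dim=512)
    pe_layer = PositionEmbeddingRandom(embed_dim // 2)

    image_embedding = torch.randn(2, embed_dim, 8, 8, 8)                   # B x C x h x w x d encoder features
    image_pe = pe_layer((8, 8, 8)).unsqueeze(0).expand(2, -1, -1, -1, -1)  # matching positional encoding
    prompt_tokens = torch.randn(2, 5, embed_dim)                           # B x N_prompts x C

    queries, keys = transformer(image_embedding, image_pe, prompt_tokens)
    print(queries.shape, keys.shape)  # torch.Size([2, 5, 48]) torch.Size([2, 512, 48])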
""" super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) - self.cross_attn_token_to_image = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d") self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) - self.cross_attn_image_to_token = Attention( - embedding_dim, num_heads, downsample_rate=attention_downsample_rate - ) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) self.skip_first_layer_pe = skip_first_layer_pe def forward( - self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor - ) -> Tuple[Tensor, Tensor]: + self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) @@ -825,38 +757,40 @@ class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + downsample_rate: the rate at which to downsample the image before projecting. """ - def __init__( - self, - embedding_dim: int, - num_heads: int, - downsample_rate: int = 1, - ) -> None: + def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads - assert ( - self.internal_dim % num_heads == 0 - ), "num_heads must divide embedding_dim." + if not self.internal_dim % num_heads == 0: + raise ValueError("num_heads must divide embedding_dim.") self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) - def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) - return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + # B x N_heads x N_tokens x C_per_head + return x.transpose(1, 2) - def _recombine_heads(self, x: Tensor) -> Tensor: + def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) - return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + # B x N_tokens x C + return x.reshape(b, n_tokens, n_heads * c_per_head) - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) @@ -884,18 +818,20 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: class PositionEmbeddingRandom(nn.Module): """ Positional encoding using random spatial frequencies. 
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`. + + Args: + num_pos_feats: the number of positional encoding features. + scale: the scale of the positional encoding. """ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: super().__init__() if scale is None or scale <= 0.0: scale = 1.0 - self.register_buffer( - "positional_encoding_gaussian_matrix", - scale * torch.randn((3, num_pos_feats)), - ) + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats))) - def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 @@ -907,7 +843,7 @@ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: # [bs=1, N=2, 128+128=256] return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: + def forward(self, size: Tuple[int, int, int]) -> torch.torch.Tensor: """Generate positional encoding for a grid of the specified size.""" h, w, d = size device: Any = self.positional_encoding_gaussian_matrix.device @@ -919,37 +855,44 @@ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: y_embed = y_embed / w z_embed = z_embed / d pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) - return pe.permute(3, 0, 1, 2) # C x H x W + # C x H x W + return pe.permute(3, 0, 1, 2) def forward_with_coords( - self, coords_input: torch.Tensor, image_size: Tuple[int, int] - ) -> torch.Tensor: + self, coords_input: torch.torch.Tensor, image_size: Tuple[int, int, int] + ) -> torch.torch.Tensor: """Positionally encode points that are not normalized to [0,1].""" coords = coords_input.clone() coords[:, :, 0] = coords[:, :, 0] / image_size[0] coords[:, :, 1] = coords[:, :, 1] / image_size[1] coords[:, :, 2] = coords[:, :, 2] / image_size[2] - return self._pe_encoding(coords.to(torch.float)) # B x N x C + # B x N x C + return self._pe_encoding(coords.to(torch.float)) class MLP(nn.Module): + """ + Multi-layer perceptron. This class is only used for `PointMappingSAM`. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + input_dim: the input dimension. + hidden_dim: the hidden dimension. + output_dim: the output dimension. + num_layers: the number of layers. + sigmoid_output: whether to apply a sigmoid activation to the output. 
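A quick usage sketch for this MLP; the 48-wide dimensions are illustrative and mirror how the point head applies it as a hypernetwork over mask tokens:

    # usage sketch (illustrative dimensions), assuming MLP stays module-level in monai.networks.nets.vista3d
    import torch
    from monai.networks.nets.vista3d import MLP

    mlp = MLP(input_dim=48, hidden_dim=48, output_dim=48, num_layers=3)
    mask_tokens = torch.randn(2, 1, 48)   # B x N_tokens x C
    print(mlp(mask_tokens).shape)         # torch.Size([2, 1, 48])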
+ """ + def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - num_layers: int, - sigmoid_output: bool = False, + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False ) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList( - nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) - ) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) self.sigmoid_output = sigmoid_output - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) if self.sigmoid_output: diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index 5372fcc8ae..eab7bac9a0 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import SegResNetDS +from monai.networks.nets import SegResNetDS, SegResNetDS2 from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -71,7 +71,7 @@ ] -class TestResNetDS(unittest.TestCase): +class TestSegResNetDS(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_DS) def test_shape(self, input_param, input_shape, expected_shape): @@ -80,47 +80,71 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + @parameterized.expand(TEST_CASE_SEGRESNET_DS) + def test_shape_ds2(self, input_param, input_shape, expected_shape): + net = SegResNetDS2(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device), with_label=False) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[1] == []) + + result = net(torch.randn(input_shape).to(device), with_point=False) + self.assertEqual(result[1].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[0] == []) + @parameterized.expand(TEST_CASE_SEGRESNET_DS2) def test_shape2(self, input_param, input_shape, expected_shape): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - if dsdepth > 1: - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual( - result[i].shape, - expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), - msg=str(input_param), - ) - else: - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) - - net.eval() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_label=False)[0] + else: + result = net(torch.randn(input_shape).to(device)) + if dsdepth > 1: + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual( + result[i].shape, + expected_shape[:2] + tuple(e // (2**i) for e in 
expected_shape[2:]), + msg=str(input_param), + ) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + if not isinstance(net, SegResNetDS2): + # eval mode of SegResNetDS2 has same output as training mode + # so only test eval mode for SegResNetDS + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) @parameterized.expand(TEST_CASE_SEGRESNET_DS3) def test_shape3(self, input_param, input_shape, expected_shapes): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_point=False)[1] + else: + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) def test_ill_arg(self): with self.assertRaises(ValueError): SegResNetDS(spatial_dims=4) + with self.assertRaises(ValueError): + SegResNetDS2(spatial_dims=4) + @SkipIfBeforePyTorchVersion((1, 10)) def test_script(self): input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0] From 599354298f24fd42562bd2cdb1ffcf39a794367f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 12 Aug 2024 21:42:09 +0800 Subject: [PATCH 19/30] add test Signed-off-by: Yiheng Wang --- tests/test_vista3d.py | 56 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/test_vista3d.py diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py new file mode 100644 index 0000000000..0fa840adf2 --- /dev/null +++ b/tests/test_vista3d.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
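The SegResNetDS2 test updates above exercise the two decoder branches through the `with_point` and `with_label` flags; a minimal standalone sketch of that behaviour, with a deliberately small configuration:

    # minimal sketch of SegResNetDS2's two decoder branches (illustrative configuration, not a test)
    import torch
    from monai.networks.nets import SegResNetDS2

    net = SegResNetDS2(init_filters=8, blocks_down=(1, 2), dsdepth=1, out_channels=2)
    x = torch.randn(1, 1, 32, 32, 32)

    point_out, auto_out = net(x)                   # both decoder branches, each (1, 2, 32, 32, 32)
    point_only, empty = net(x, with_label=False)   # point branch only; the auto output stays empty
    auto_only = net(x, with_point=False)[1]        # auto (class) branch only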
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VISTA3D, SegResNetDS2 +from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {}, + (1, 1, 64, 64, 64), + (1, 1, 64, 64, 64), + ] +] + + +class TestVista3d(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): + segresnet = SegResNetDS2( + in_channels=args["in_channels"], + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=args["encoder_embed_dim"], + init_filters=args["encoder_embed_dim"], + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=args["encoder_embed_dim"], n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) + net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) + with eval_mode(net): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 6c46d47a9adee8983180b5979da5324bb9eac74a Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 12 Aug 2024 22:24:22 +0800 Subject: [PATCH 20/30] fix Tensor tensor issue Signed-off-by: Yiheng Wang --- monai/networks/nets/vista3d.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 662f6c369b..f61e433a52 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -123,7 +123,7 @@ def sample_points_patch_val( use_center=use_center, ) point_labels = self.convert_point_label(point_labels, mapped_label_set) - return (point_coords, point_labels, torch.Tensor(label_set).to(point_coords.device).unsqueeze(-1)) + return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1)) def update_point_to_patch( self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor @@ -141,8 +141,8 @@ def update_point_to_patch( patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] # update point coords - patch_starts_tensor = unsqueeze_left(torch.Tensor(patch_starts, device=point_coords.device), 2) - patch_ends_tensor = unsqueeze_left(torch.Tensor(patch_ends, device=point_coords.device), 2) + patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2) + patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2) # [1 N 1] indices = torch.logical_and( ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2) @@ -198,7 +198,7 @@ def connected_components_combine( ] ) ) - inside_tensor = torch.Tensor(inside).to(logits.device) + inside_tensor = torch.tensor(inside).to(logits.device) nan_mask = torch.isnan(_logits) # _logits are converted to binary [B1, 1, H, W, D] _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() @@ -655,8 +655,8 @@ def forward( Must have shape B x N_points x embedding_dim for any N_points. Returns: - torch.torch.Tensor: the processed point_embedding. 
- torch.torch.Tensor: the processed image_embedding. + torch.Tensor: the processed point_embedding. + torch.Tensor: the processed image_embedding. """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C image_embedding = image_embedding.flatten(2).permute(0, 2, 1) From 979bbe7b4237c0b7bd45c3fec9965486dec204d3 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 12 Aug 2024 22:32:31 +0800 Subject: [PATCH 21/30] add vista3d doc Signed-off-by: Yiheng Wang --- docs/source/networks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 38c59f73a5..1810fec49b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -561,6 +561,11 @@ Nets .. autoclass:: UNETR :members: +`VISTA3D` +~~~~~~~~~ +.. autoclass:: VISTA3D + :members: + `SwinUNETR` ~~~~~~~~~~~ .. autoclass:: SwinUNETR From b8a95d04bd85d62903fae3b4cf027cb97e657ebb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 12 Aug 2024 22:37:42 +0800 Subject: [PATCH 22/30] remove unnecessary check Signed-off-by: Yiheng Wang --- monai/networks/nets/vista3d.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index f61e433a52..d528cc59a6 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -329,9 +329,6 @@ def forward( if point_coords is None and class_vector is None: return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) - if point_coords is not None and point_labels is None: - raise ValueError("point_labels must be provided when point_coords is provided.") - bs = self.get_bs(class_vector, point_coords) if patch_coords is not None: # if during validation and perform enable based point-validation. From 65828b7ffd09d7707eaed1a1d2eb5af44a76adeb Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 12 Aug 2024 18:29:32 -0400 Subject: [PATCH 23/30] Add test case and fix bug Signed-off-by: heyufan1995 --- monai/transforms/utils.py | 4 ++-- tests/test_vista3d.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e32bf6fc48..363fce91be 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1274,10 +1274,10 @@ def convert_points_to_disc( coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) for b, n in np.ndindex(*point.shape[:2]): - point_bn = unsqueeze_right(point[b, n], 6) + point_bn = unsqueeze_right(point[b, n], 4) if point_label[b, n] > -1: channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 - pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + pow_diff = torch.pow(coords[b, channel] - point_bn, 2) if disc: masks[b, channel] += pow_diff.sum(0) < radius**2 else: diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py index 0fa840adf2..80ba6ec54f 100644 --- a/tests/test_vista3d.py +++ b/tests/test_vista3d.py @@ -28,6 +28,33 @@ {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 2}, + {}, + (1, 2, 64, 64, 64), + (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {'class_vector': torch.tensor([1,2,3], device=device)}, + (1, 1, 64, 64, 64), + (3, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {'point_coords': torch.tensor([[[1,2,3],[1,2,3]]], device=device), + 'point_labels':torch.tensor([[1,0]], device=device)}, + (1, 1, 64, 64, 
64), + (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {'class_vector': torch.tensor([1,2], device=device), + 'point_coords': torch.tensor([[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]], device=device), + 'point_labels':torch.tensor([[1,0],[1,0]], device=device)}, + (1, 1, 64, 64, 64), + (2, 1, 64, 64, 64), ] ] @@ -48,7 +75,10 @@ def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device)) + result = net.forward(torch.randn(input_shape).to(device), + point_coords=input_params.get('point_coords', None), + point_labels=input_params.get('point_labels', None), + class_vector=input_params.get('class_vector', None)) self.assertEqual(result.shape, expected_shape) From 26ccfad2b2530de6c19258ac221996c6e7b94826 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:30:13 +0000 Subject: [PATCH 24/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_vista3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py index 80ba6ec54f..9472d894b3 100644 --- a/tests/test_vista3d.py +++ b/tests/test_vista3d.py @@ -50,7 +50,7 @@ ], [ {"encoder_embed_dim": 48, "in_channels": 1}, - {'class_vector': torch.tensor([1,2], device=device), + {'class_vector': torch.tensor([1,2], device=device), 'point_coords': torch.tensor([[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]], device=device), 'point_labels':torch.tensor([[1,0],[1,0]], device=device)}, (1, 1, 64, 64, 64), @@ -75,7 +75,7 @@ def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device), + result = net.forward(torch.randn(input_shape).to(device), point_coords=input_params.get('point_coords', None), point_labels=input_params.get('point_labels', None), class_vector=input_params.get('class_vector', None)) From 53b0542fcbdb8d38336886d016f6f4e6e417190f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 13 Aug 2024 11:32:48 +0800 Subject: [PATCH 25/30] fix ci issues Signed-off-by: Yiheng Wang --- monai/networks/nets/vista3d.py | 5 ++-- tests/test_vista3d.py | 42 +++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index d528cc59a6..8ff43df4c0 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -34,8 +34,9 @@ class VISTA3D(nn.Module): """ - VISTA3D based on `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography - https://arxiv.org/abs/2406.05285>`_. + VISTA3D based on: + `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography + `_. Args: image_encoder: image encoder backbone for feature extraction. 
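The test cases added above cover both prompt types; roughly the same usage through the `vista3d132` factory looks like the sketch below (the 64^3 input and click coordinates are illustrative):

    # usage sketch mirroring the test cases above (illustrative input size and click coordinates)
    import torch
    from monai.networks.nets.vista3d import vista3d132

    model = vista3d132(encoder_embed_dim=48, in_channels=1).eval()
    image = torch.randn(1, 1, 64, 64, 64)

    with torch.no_grad():
        # automatic segmentation: one logit channel per requested global class index
        auto_logits = model(image, class_vector=torch.tensor([1, 2, 3]))   # (3, 1, 64, 64, 64)
        # interactive segmentation: one positive and one negative click on a single object
        point_logits = model(
            image,
            point_coords=torch.tensor([[[32, 32, 32], [8, 8, 8]]]),
            point_labels=torch.tensor([[1, 0]]),
        )                                                                  # (1, 1, 64, 64, 64)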
diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py index 9472d894b3..11d32a2de7 100644 --- a/tests/test_vista3d.py +++ b/tests/test_vista3d.py @@ -23,39 +23,33 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASES = [ + [{"encoder_embed_dim": 48, "in_channels": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)], + [{"encoder_embed_dim": 48, "in_channels": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)], [ {"encoder_embed_dim": 48, "in_channels": 1}, - {}, - (1, 1, 64, 64, 64), - (1, 1, 64, 64, 64), - ], - [ - {"encoder_embed_dim": 48, "in_channels": 2}, - {}, - (1, 2, 64, 64, 64), - (1, 1, 64, 64, 64), - ], - [ - {"encoder_embed_dim": 48, "in_channels": 1}, - {'class_vector': torch.tensor([1,2,3], device=device)}, + {"class_vector": torch.tensor([1, 2, 3], device=device)}, (1, 1, 64, 64, 64), (3, 1, 64, 64, 64), ], [ {"encoder_embed_dim": 48, "in_channels": 1}, - {'point_coords': torch.tensor([[[1,2,3],[1,2,3]]], device=device), - 'point_labels':torch.tensor([[1,0]], device=device)}, + { + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64), ], [ {"encoder_embed_dim": 48, "in_channels": 1}, - {'class_vector': torch.tensor([1,2], device=device), - 'point_coords': torch.tensor([[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]], device=device), - 'point_labels':torch.tensor([[1,0],[1,0]], device=device)}, + { + "class_vector": torch.tensor([1, 2], device=device), + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0], [1, 0]], device=device), + }, (1, 1, 64, 64, 64), (2, 1, 64, 64, 64), - ] + ], ] @@ -75,10 +69,12 @@ def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) with eval_mode(net): - result = net.forward(torch.randn(input_shape).to(device), - point_coords=input_params.get('point_coords', None), - point_labels=input_params.get('point_labels', None), - class_vector=input_params.get('class_vector', None)) + result = net.forward( + torch.randn(input_shape).to(device), + point_coords=input_params.get("point_coords", None), + point_labels=input_params.get("point_labels", None), + class_vector=input_params.get("class_vector", None), + ) self.assertEqual(result.shape, expected_shape) From 3368a4b68851dd9b4f1dc5f12682f4f313a3202f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 13 Aug 2024 11:52:58 +0800 Subject: [PATCH 26/30] skip old torch test Signed-off-by: Yiheng Wang --- tests/test_vista3d.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py index 11d32a2de7..d3b4e0c10e 100644 --- a/tests/test_vista3d.py +++ b/tests/test_vista3d.py @@ -19,6 +19,7 @@ from monai.networks import eval_mode from monai.networks.nets import VISTA3D, SegResNetDS2 from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick device = "cuda" if torch.cuda.is_available() else "cpu" @@ -53,6 +54,8 @@ ] +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick class TestVista3d(unittest.TestCase): @parameterized.expand(TEST_CASES) From 5d6e7b118e4ef052ed899a451482bf5b0bdc2e3c Mon Sep 17 00:00:00 2001 From: Yiheng Wang 
<68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:20:50 +0800 Subject: [PATCH 27/30] Update monai/networks/nets/vista3d.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/networks/nets/vista3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 8ff43df4c0..fbafca2eed 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -78,8 +78,8 @@ def convert_point_label( Args: point_label: the point label tensor, [B, N]. label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, - this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot - evaluation, this label_set should be the original index. + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. special_index: the special class index that needs to be converted. """ if label_set is None: From 6b53216e9b3d8b3ec5b43e881b8a2ff63330347b Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 13 Aug 2024 13:55:03 -0400 Subject: [PATCH 28/30] Update docstring and removed unused layer Signed-off-by: heyufan1995 --- monai/networks/nets/segresnet_ds.py | 22 ++++++- monai/networks/nets/vista3d.py | 91 ++++++++++++++++------------- 2 files changed, 70 insertions(+), 43 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index a7c7e7cd94..be95a909fd 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -430,8 +430,24 @@ def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tenso class SegResNetDS2(SegResNetDS): """ - SegResNetDS2 based on `SegResNetDS` and adds an additional decorder branch. - It is the image encoder used by VISTA3D. + SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D + `_. + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments. Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + blocks_up: number of upsample blocks (optional). + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. + At dsdepth==1,only a single output is returned. + preprocess: optional callable function to apply before the model's forward pass + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring + image spacing into an approximately isotropic space. + Otherwise, by default, the kernel size and downsampling is always isotropic. 
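The two decoder branches documented here are typically frozen independently during staged fine-tuning via `set_auto_grad`; a small sketch, assuming that method keeps its current argument names:

    # sketch of staged fine-tuning with the dual decoders (assumes set_auto_grad keeps auto_freeze/point_freeze)
    from monai.networks.nets import SegResNetDS2

    encoder = SegResNetDS2(init_filters=8, blocks_down=(1, 2), dsdepth=1, out_channels=2)

    # freeze the automatic (class) branch while the point branch keeps training;
    # the shared encoder only stays trainable when neither branch is frozen
    encoder.set_auto_grad(auto_freeze=True, point_freeze=False)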
+ """ def __init__( @@ -467,7 +483,7 @@ def __init__( self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) def forward( # type: ignore - self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, **kwargs + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: """ Args: diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index fbafca2eed..779c658f85 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -31,6 +31,27 @@ __all__ = ["VISTA3D", "vista3d132"] +def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): + """ + Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. + The model treats class index larger than 132 as zero-shot. + + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. + """ + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) + return vista class VISTA3D(nn.Module): """ @@ -55,7 +76,7 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 - def get_bs(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: """Get number of foreground classes based on class and point prompt.""" if class_vector is None: if point_coords is None: @@ -305,18 +326,18 @@ def forward( input_images: [1, 1, H, W, D] point_coords: [B, N, 3] point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. - 2/3 means negative/postive ponits for special supported class like tumor. + 2/3 means negative/postive ponits for special supported class like tumor. class_vector: [B, 1], the global class index prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if - the points are for zero-shot or supported class. When class_vector and point_coords are both - provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] - will be considered novel class. + the points are for zero-shot or supported class. When class_vector and point_coords are both + provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] + will be considered novel class. patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, - this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot - evaluation, this label_set should be the original index. 
+ this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. This is the transposed raw output from sliding_window_inferer before any postprocessing. When user click points to perform auto-results correction, this can be the auto-results. @@ -330,7 +351,7 @@ def forward( if point_coords is None and class_vector is None: return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) - bs = self.get_bs(class_vector, point_coords) + bs = self.get_foreground_class_count(class_vector, point_coords) if patch_coords is not None: # if during validation and perform enable based point-validation. if labels is not None and label_set is not None: @@ -412,10 +433,17 @@ def __init__( self, feature_size: int, max_prompt: int = 32, - num_add_mask_tokens: int = 2, n_classes: int = 512, last_supported: int = 132, ): + """ Interactive point head used for VISTA3D. + Adapted from segment anything Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + Args: + feature_size: feature channel from encoder. + max_prompt: max prompt number in each forward iteration. + n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. + last_supported: number of classes the model support, this value should match the trained model weights. + """ super().__init__() transformer_dim = feature_size self.max_prompt = max_prompt @@ -444,12 +472,6 @@ def __init__( ) self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) - - # MultiMask output - self.num_add_mask_tokens = num_add_mask_tokens - self.output_add_hypernetworks_mlps = nn.ModuleList( - [MLP(transformer_dim, transformer_dim, transformer_dim, 3) for i in range(self.num_add_mask_tokens)] - ) # class embedding self.n_classes = n_classes self.last_supported = last_supported @@ -464,6 +486,12 @@ def forward( point_labels: torch.Tensor, class_vector: torch.Tensor | None = None, ): + """ Args: + out: feature from encoder, [1, C, H, W, C] + point_coords: point coordinates, [B, N, 3] + point_labels: point labels, [B, N] + class_vector: class prompts, [B] + """ # downsample out out_low = self.feat_downsample(out) out_shape = tuple(out.shape[-3:]) @@ -525,7 +553,14 @@ def forward( class ClassMappingClassify(nn.Module): - def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = False): + """ Class head that performs automatic segmentation based on class vector. + """ + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True): + """Args: + n_classes: maximum number of class embedding. + feature_size: class embedding size. + use_mlp: use mlp to further map class embedding. + """ super().__init__() self.use_mlp = use_mlp if use_mlp: @@ -571,30 +606,6 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): return torch.cat(masks, 1), class_embedding - -def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): - """ - Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. - The model treats class index larger than 132 as zero-shot. - - Args: - encoder_embed_dim: hidden dimension for encoder. - in_channels: input channel number. 
- """ - segresnet = SegResNetDS2( - in_channels=in_channels, - blocks_down=(1, 2, 2, 4, 4), - norm="instance", - out_channels=encoder_embed_dim, - init_filters=encoder_embed_dim, - dsdepth=1, - ) - point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) - class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) - vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) - return vista - - class TwoWayTransformer(nn.Module): def __init__( self, From cddcd3092c528a3a4107a81d79761a8dd866aacf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:55:38 +0000 Subject: [PATCH 29/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vista3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 779c658f85..179d5be491 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -442,7 +442,7 @@ def __init__( feature_size: feature channel from encoder. max_prompt: max prompt number in each forward iteration. n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. - last_supported: number of classes the model support, this value should match the trained model weights. + last_supported: number of classes the model support, this value should match the trained model weights. """ super().__init__() transformer_dim = feature_size From 5fe376c3a05ae5ca5f37fdbd1fb46514c2ac7af7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 15 Aug 2024 13:19:01 +0800 Subject: [PATCH 30/30] fix format issue Signed-off-by: Yiheng Wang --- monai/networks/nets/segresnet_ds.py | 3 +- monai/networks/nets/vista3d.py | 47 ++++++++++++++--------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index be95a909fd..1ac5a79ee3 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -432,6 +432,7 @@ class SegResNetDS2(SegResNetDS): """ SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D `_. + Args: spatial_dims: spatial dimension of the input data. Defaults to 3. init_filters: number of output channels for initial convolution layer. Defaults to 32. @@ -483,7 +484,7 @@ def __init__( self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) def forward( # type: ignore - self, x: torch.Tensor, with_point: bool = True, with_label: bool = True, + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: """ Args: diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 179d5be491..fe7f93d493 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -31,6 +31,7 @@ __all__ = ["VISTA3D", "vista3d132"] + def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): """ Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. 
@@ -53,6 +54,7 @@ def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) return vista + class VISTA3D(nn.Module): """ VISTA3D based on: @@ -429,20 +431,16 @@ def forward( class PointMappingSAM(nn.Module): - def __init__( - self, - feature_size: int, - max_prompt: int = 32, - n_classes: int = 512, - last_supported: int = 132, - ): - """ Interactive point head used for VISTA3D. - Adapted from segment anything Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. - Args: - feature_size: feature channel from encoder. - max_prompt: max prompt number in each forward iteration. - n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. - last_supported: number of classes the model support, this value should match the trained model weights. + def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132): + """Interactive point head used for VISTA3D. + Adapted from segment anything: + `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + feature_size: feature channel from encoder. + max_prompt: max prompt number in each forward iteration. + n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. + last_supported: number of classes the model support, this value should match the trained model weights. """ super().__init__() transformer_dim = feature_size @@ -486,11 +484,11 @@ def forward( point_labels: torch.Tensor, class_vector: torch.Tensor | None = None, ): - """ Args: - out: feature from encoder, [1, C, H, W, C] - point_coords: point coordinates, [B, N, 3] - point_labels: point labels, [B, N] - class_vector: class prompts, [B] + """Args: + out: feature from encoder, [1, C, H, W, C] + point_coords: point coordinates, [B, N, 3] + point_labels: point labels, [B, N] + class_vector: class prompts, [B] """ # downsample out out_low = self.feat_downsample(out) @@ -553,13 +551,13 @@ def forward( class ClassMappingClassify(nn.Module): - """ Class head that performs automatic segmentation based on class vector. - """ + """Class head that performs automatic segmentation based on class vector.""" + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True): """Args: - n_classes: maximum number of class embedding. - feature_size: class embedding size. - use_mlp: use mlp to further map class embedding. + n_classes: maximum number of class embedding. + feature_size: class embedding size. + use_mlp: use mlp to further map class embedding. """ super().__init__() self.use_mlp = use_mlp @@ -606,6 +604,7 @@ def forward(self, src: torch.Tensor, class_vector: torch.Tensor): return torch.cat(masks, 1), class_embedding + class TwoWayTransformer(nn.Module): def __init__( self,