diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py index 8dd2d205e..ccd9aac1f 100644 --- a/monailabel/tasks/train/basic_train.py +++ b/monailabel/tasks/train/basic_train.py @@ -104,6 +104,7 @@ def __init__( stats_path=None, train_save_interval=20, val_interval=1, + n_saved=5, final_filename="checkpoint_final.pt", key_metric_filename="model.pt", model_dict_key="model", @@ -123,6 +124,7 @@ def __init__( :param stats_path: Path to save the train stats :param train_save_interval: checkpoint save interval for training :param val_interval: validation interval (run every x epochs) + :param n_saved: max checkpoints to save :param final_filename: name of final checkpoint that will be saved :param key_metric_filename: best key metric model file name :param model_dict_key: key to save network weights into checkpoint @@ -157,6 +159,7 @@ def __init__( self._train_save_interval = train_save_interval self._val_interval = val_interval + self._n_saved = n_saved self._final_filename = final_filename self._key_metric_filename = key_metric_filename self._model_dict_key = model_dict_key @@ -340,7 +343,7 @@ def config(self): @staticmethod def _validate_transforms(transforms, step="Training", name="pre"): - if not transforms or isinstance(transforms, Compose): + if not transforms or isinstance(transforms, Compose) or callable(transforms): return transforms if isinstance(transforms, list): return Compose(transforms) @@ -528,7 +531,7 @@ def _create_evaluator(self, context: Context): save_dict={self._model_dict_key: context.network}, save_key_metric=True, key_metric_filename=self._key_metric_filename, - n_saved=5, + n_saved=self._n_saved, ) ) @@ -560,7 +563,7 @@ def _create_trainer(self, context: Context): key_metric_filename=f"train_{self._key_metric_filename}" if context.evaluator else self._key_metric_filename, - n_saved=5, + n_saved=self._n_saved, ) ) diff --git 
a/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java b/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java index 1b75a0e23..119f5f170 100644 --- a/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java +++ b/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java @@ -94,17 +94,20 @@ public void run() { list.addIntParameter("Width", "Width", bbox[2]); list.addIntParameter("Height", "Height", bbox[3]); + boolean override = !info.models.get(selectedModel).nuclick; + list.addBooleanParameter("Override", "Override", override); + if (Dialogs.showParameterDialog("MONAILabel", list)) { String model = (String) list.getChoiceParameterValue("Model"); bbox[0] = list.getIntParameterValue("X").intValue(); bbox[1] = list.getIntParameterValue("Y").intValue(); bbox[2] = list.getIntParameterValue("Width").intValue(); bbox[3] = list.getIntParameterValue("Height").intValue(); + override = list.getBooleanParameterValue("Override").booleanValue(); selectedModel = model; selectedBBox = bbox; - boolean isNuClick = info.models.get(model).nuclick; - runInference(model, new HashSet(Arrays.asList(labels.get(model))), bbox, imageData, isNuClick); + runInference(model, new HashSet(Arrays.asList(labels.get(model))), bbox, imageData, override); } } catch (Exception ex) { ex.printStackTrace(); @@ -145,9 +148,9 @@ private int[] getBBOX(ROI roi) { } private void runInference(String model, Set labels, int[] bbox, ImageData imageData, - boolean isNuClick) throws SAXException, IOException, ParserConfigurationException, InterruptedException { + boolean override) throws SAXException, IOException, ParserConfigurationException, InterruptedException { logger.info("MONAILabel Annotation - Run Inference..."); - logger.info("Model: " + model + "; IsNuClick: " + isNuClick + "; Labels: " + labels); + logger.info("Model: " + model + "; override: " + override + "; Labels: " + 
def trainer(self) -> Optional[TrainTask]:
    """Build the training task for the NuClick model of this configuration.

    Checkpoints are written below ``<model_dir>/<name>``. ``self.path[0]`` is
    passed as the checkpoint to load from and ``self.path[1]`` as the path the
    trained model is published to.

    :return: a ``lib.trainers.NuClick`` task configured with the defaults below
    """
    output_dir = os.path.join(self.model_dir, self.name)
    task: TrainTask = lib.trainers.NuClick(
        model_dir=output_dir,
        network=self.network,
        load_path=self.path[0],
        publish_path=self.path[1],
        labels=self.labels,
        # The previous text ("Train Nuclei DeepEdit Model") was copied from the
        # DeepEdit configuration and mislabelled this task.
        description="Train Nuclei NuClick Model",
        train_save_interval=1,
        # Default request options for this task.
        config={
            "max_epochs": 10,
            "train_batch_size": 64,
            "dataset_max_region": (10240, 10240),
            "dataset_limit": 0,
            "dataset_randomize": True,
        },
    )
    return task
a/sample-apps/pathology/lib/infers/nuclick.py b/sample-apps/pathology/lib/infers/nuclick.py index d42156a6a..227c3f28d 100644 --- a/sample-apps/pathology/lib/infers/nuclick.py +++ b/sample-apps/pathology/lib/infers/nuclick.py @@ -130,8 +130,7 @@ def __call__(self, data): return d @staticmethod - def get_clickmap_boundingbox(cx, cy, m, n): - bb = 128 + def get_clickmap_boundingbox(cx, cy, m, n, bb=128): click_map = np.zeros((m, n), dtype=np.uint8) # Removing points out of image dimension (these points may have been clicked unwanted) @@ -162,9 +161,7 @@ def get_clickmap_boundingbox(cx, cy, m, n): return click_map, bounding_boxes @staticmethod - def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n): - bb = 128 - + def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n, bb=128): # total = number of clicks total = len(bounding_boxes) img = np.array([img]) # img.shape=(1,3,m,n) diff --git a/sample-apps/pathology/lib/nets/__init__.py b/sample-apps/pathology/lib/nets/__init__.py deleted file mode 100644 index 2bca7a6c9..000000000 --- a/sample-apps/pathology/lib/nets/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from .unet import UNet diff --git a/sample-apps/pathology/lib/nets/unet.py b/sample-apps/pathology/lib/nets/unet.py deleted file mode 100644 index 60366523b..000000000 --- a/sample-apps/pathology/lib/nets/unet.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) - # if you have padding issues, see - # 
https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) - - -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): - super().__init__() - self.net_name = "UNet" - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - - self.inc = DoubleConv(n_channels, 64) - self.down1 = Down(64, 128) - self.down2 = Down(128, 256) - self.down3 = Down(256, 512) - factor = 2 if bilinear else 1 - self.down4 = Down(512, 1024 // factor) - self.up1 = Up(1024, 512 // factor, bilinear) - self.up2 = Up(512, 256 // factor, bilinear) - self.up3 = Up(256, 128 // factor, bilinear) - self.up4 = Up(128, 64, bilinear) - self.outc = OutConv(64, n_classes) - - def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - x5 = self.down4(x4) - x = self.up1(x5, x4) - x = self.up2(x, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits diff --git a/sample-apps/pathology/lib/trainers/__init__.py b/sample-apps/pathology/lib/trainers/__init__.py index 6463e5137..438bef6d5 100644 --- a/sample-apps/pathology/lib/trainers/__init__.py +++ b/sample-apps/pathology/lib/trainers/__init__.py @@ -10,4 +10,5 @@ # limitations under the License. 
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
import random

import cv2
import numpy as np

# Import the submodule explicitly: a bare ``import skimage`` does not guarantee
# that ``skimage.measure`` is loaded on every scikit-image release.
import skimage.measure
import torch
from ignite.metrics import Accuracy
from lib.handlers import TensorBoardImageHandler
from lib.transforms import FilterImaged
from lib.utils import split_dataset, split_nuclei_dataset
from monai.config import KeysCollection
from monai.handlers import from_engine
from monai.inferers import SimpleInferer
from monai.losses import DiceLoss
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsChannelFirstd,
    AsDiscreted,
    EnsureTyped,
    LoadImaged,
    MapTransform,
    RandomizableTransform,
    RandRotate90d,
    ScaleIntensityRangeD,
    ToNumpyd,
    TorchVisiond,
    ToTensord,
    Transform,
)
from tqdm import tqdm

from monailabel.interfaces.datastore import Datastore
from monailabel.tasks.train.basic_train import BasicTrainTask, Context

logger = logging.getLogger(__name__)


class NuClick(BasicTrainTask):
    """Training task for the click-guided NuClick nuclei segmentation model.

    Every training sample is a ``patch_size`` x ``patch_size`` crop around one
    nucleus. Two guidance channels are appended to the image: one click inside
    the target nucleus and clicks on (some of) the neighbouring nuclei.
    """

    def __init__(
        self,
        model_dir,
        network,
        labels,
        tile_size=(256, 256),
        patch_size=128,
        min_area=5,
        description="Pathology NuClick Segmentation",
        **kwargs,
    ):
        """
        :param model_dir: directory forwarded to ``BasicTrainTask``
        :param network: network instance returned by :meth:`network`
        :param labels: label groups forwarded to ``split_dataset``
        :param tile_size: tile size forwarded to ``split_dataset``
        :param patch_size: edge length of the square patch cut around each nucleus
        :param min_area: area threshold used when splitting nuclei and filtering neighbours
        :param description: description forwarded to ``BasicTrainTask``
        :param kwargs: any other ``BasicTrainTask`` option
        """
        self._network = network
        self.labels = labels
        self.tile_size = tile_size
        self.patch_size = patch_size
        self.min_area = min_area
        super().__init__(model_dir, description, **kwargs)

    def network(self, context: Context):
        return self._network

    def optimizer(self, context: Context):
        return torch.optim.Adam(context.network.parameters(), 0.0001)

    def loss_function(self, context: Context):
        return DiceLoss(sigmoid=True, squared_pred=True)

    def pre_process(self, request, datastore: Datastore):
        """Build the training datalist with one entry per nucleus.

        The datastore is first split by ``split_dataset``; every resulting item
        is then expanded by ``split_nuclei_dataset`` into one item per nucleus.
        A positive ``dataset_limit`` in the request caps the number of entries.

        :param request: training request (a mapping of options)
        :param datastore: datastore providing images and labels
        :return: list of datalist entries
        """
        self.cleanup(request)

        cache_dir = os.path.join(self.get_cache_dir(request), "train_ds")
        source = request.get("dataset_source")
        max_region = request.get("dataset_max_region", (10240, 10240))
        max_region = (max_region, max_region) if isinstance(max_region, int) else max_region[:2]
        limit = request.get("dataset_limit", 0)

        ds = split_dataset(
            datastore=datastore,
            cache_dir=cache_dir,
            source=source,
            groups=self.labels,
            tile_size=self.tile_size,
            max_region=max_region,
            limit=limit,
            randomize=request.get("dataset_randomize", True),
        )

        logger.info(f"Split data (len: {len(ds)}) based on each nuclei")
        ds_new = []
        for d in tqdm(ds):
            ds_new.extend(split_nuclei_dataset(d, min_area=self.min_area))
            # Stop as soon as enough per-nucleus entries have been collected.
            if 0 < limit < len(ds_new):
                ds_new = ds_new[:limit]
                break
        return ds_new

    def train_pre_transforms(self, context: Context):
        return [
            LoadImaged(keys=("image", "label"), dtype=np.uint8),
            FilterImaged(keys="image", min_size=5),
            FlattenLabeld(keys="label"),
            AsChannelFirstd(keys="image"),
            AddChanneld(keys="label"),
            ExtractPatchd(keys=("image", "label"), patch_size=self.patch_size),
            SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=self.min_area),
            # ColorJitter runs on tensors, hence the tensor/numpy round trip around it.
            ToTensord(keys="image"),
            TorchVisiond(
                keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
            ),
            ToNumpyd(keys="image"),
            RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            AddPointGuidanceSignald(image="image", label="label", others="others"),
            EnsureTyped(keys=("image", "label")),
        ]

    def train_post_transforms(self, context: Context):
        return [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ]

    def train_key_metric(self, context: Context):
        return {"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}

    def val_key_metric(self, context: Context):
        return {"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}

    def val_inferer(self, context: Context):
        return SimpleInferer()

    def train_handlers(self, context: Context):
        handlers = super().train_handlers(context)
        # Only the rank-0 process writes TensorBoard images.
        if context.local_rank == 0:
            handlers.append(TensorBoardImageHandler(log_dir=context.events_dir, batch_limit=4))
        return handlers


class FlattenLabeld(MapTransform):
    """Replace each keyed mask by its 4-connected component label image."""

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            _, labels, _, _ = cv2.connectedComponentsWithStats(d[key], 4, cv2.CV_32S)
            # Keep the 32-bit ids: casting to uint8 wraps ids above 255, which
            # merges distinct nuclei (or turns one into background) and makes
            # the ``label == mask_value`` comparison in SplitLabeld miss them.
            d[key] = labels.astype(np.int32)
        return d


class ExtractPatchd(MapTransform):
    """Crop a square patch around ``d[centroid_key]`` from every keyed array.

    The arrays are indexed as ``img[:, x, y]``: the first axis is kept whole
    and the last two axes are cropped. Crops smaller than the patch are
    zero-padded at the end of each spatial axis.
    """

    def __init__(self, keys: KeysCollection, centroid_key="centroid", patch_size=128):
        super().__init__(keys)
        self.centroid_key = centroid_key
        self.patch_size = patch_size

    def __call__(self, data):
        d = dict(data)

        centroid = d[self.centroid_key]  # create mask based on centroid (select nuclei based on centroid)
        roi_size = (self.patch_size, self.patch_size)

        for key in self.keys:
            img = d[key]
            x_start, x_end, y_start, y_end = self.bbox(self.patch_size, centroid, img.shape[-2:])
            cropped = img[:, x_start:x_end, y_start:y_end]
            d[key] = self.pad_to_shape(cropped, roi_size)
        return d

    @staticmethod
    def bbox(patch_size, centroid, size):
        """Return ``(x_start, x_end, y_start, y_end)`` of the patch window.

        The window is centred on ``centroid`` and shifted back inside
        ``size = (m, n)`` when it would cross the upper border. If an axis is
        shorter than ``patch_size`` the window covers the whole axis.
        """
        x, y = centroid
        m, n = size

        x_start = int(max(x - patch_size / 2, 0))
        y_start = int(max(y - patch_size / 2, 0))
        x_end = x_start + patch_size
        y_end = y_start + patch_size
        if x_end > m:
            x_end = m
            # Clamp at 0: a negative start would be read as "from the end" by
            # slicing and silently crop the wrong region of a small image.
            x_start = max(m - patch_size, 0)
        if y_end > n:
            y_end = n
            y_start = max(n - patch_size, 0)
        return x_start, x_end, y_start, y_end

    @staticmethod
    def pad_to_shape(img, shape):
        """Zero-pad the last two axes of a 3-D ``img`` at their end up to ``shape``."""
        img_shape = img.shape[-2:]
        s_diff = np.array(shape) - np.array(img_shape)
        diff = [(0, 0), (0, s_diff[0]), (0, s_diff[1])]
        return np.pad(
            img,
            diff,
            mode="constant",
            constant_values=0,
        )


class SplitLabeld(Transform):
    """Split an instance label into the target nucleus and the other nuclei.

    ``d[label]`` becomes the binary (uint8) mask of the pixels equal to
    ``d[mask_value]``. ``d[others]`` receives the remaining labels, with
    connected components whose area is not above ``min_area`` removed.
    """

    def __init__(self, label="label", others="others", mask_value="mask_value", min_area=5):
        self.label = label
        self.others = others
        self.mask_value = mask_value
        self.min_area = min_area

    def __call__(self, data):
        d = dict(data)
        label = d[self.label]
        mask_value = d[self.mask_value]

        mask = np.uint8(label == mask_value)
        others = (1 - mask) * label
        # Relabeling works on the 2-D plane; restore the leading axis afterwards.
        others = self._mask_relabeling(others[0], min_area=self.min_area)[np.newaxis]

        d[self.label] = mask
        d[self.others] = others
        return d

    @staticmethod
    def _mask_relabeling(mask, min_area=5):
        """Return a copy of ``mask`` keeping only components with area > ``min_area``."""
        res = np.zeros_like(mask)
        for value in np.unique(mask):
            if value == 0:
                continue

            components = skimage.measure.label(mask == value, connectivity=1)
            for stat in skimage.measure.regionprops(components):
                if stat.area > min_area:
                    res[stat.coords[:, 0], stat.coords[:, 1]] = value
        return res


class AddPointGuidanceSignald(RandomizableTransform):
    """Append an inclusion and an exclusion click map as two extra image channels.

    NOTE(review): sampling uses the global ``numpy.random`` / ``random`` state
    rather than ``self.R``, so ``set_random_state`` on this transform has no
    effect on it.
    """

    def __init__(self, image="image", label="label", others="others", drop_rate=0.5, jitter_range=3):
        super().__init__()

        self.image = image
        self.label = label
        self.others = others
        self.drop_rate = drop_rate
        self.jitter_range = jitter_range

    def __call__(self, data):
        d = dict(data)

        image = d[self.image]
        mask = d[self.label]
        others = d[self.others]

        inc_sig = self.inclusion_map(mask[0])
        exc_sig = self.exclusion_map(others[0], drop_rate=self.drop_rate, jitter_range=self.jitter_range)

        image = np.concatenate((image, inc_sig[np.newaxis, ...], exc_sig[np.newaxis, ...]), axis=0)
        d[self.image] = image
        return d

    @staticmethod
    def inclusion_map(mask):
        """Return a map with a single 1 at a randomly chosen non-zero pixel of ``mask``.

        The map is all zeros when ``mask`` has no non-zero pixel.
        """
        point_mask = np.zeros_like(mask)
        indices = np.argwhere(mask > 0)
        if len(indices) > 0:
            idx = np.random.randint(0, len(indices))
            point_mask[indices[idx, 0], indices[idx, 1]] = 1

        return point_mask

    @staticmethod
    def exclusion_map(others, jitter_range=3, drop_rate=0.5):
        """Return a map with one jittered centroid click per region of ``others``.

        Each region is skipped with probability ``drop_rate``. Kept centroids
        are shifted by up to ``jitter_range`` pixels per axis and clamped to
        the image bounds.
        """
        point_mask = np.zeros_like(others)
        max_x = point_mask.shape[0] - 1
        max_y = point_mask.shape[1] - 1

        stats = skimage.measure.regionprops(others)
        for stat in stats:
            x, y = stat.centroid
            # random drop
            if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
                continue

            # random jitter
            x = int(math.floor(x)) + random.randint(a=-jitter_range, b=jitter_range)
            y = int(math.floor(y)) + random.randint(a=-jitter_range, b=jitter_range)
            x = min(max(0, x), max_x)
            y = min(max(0, y), max_y)
            point_mask[x, y] = 1

        return point_mask
def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value", min_area=5):
    """Expand one datalist entry into one entry per nucleus of its label image.

    The label file ``d["label"]`` is loaded as uint8 and split into 4-connected
    components. Each component with an area of at least ``min_area`` yields a
    deep copy of ``d`` extended with its floored centroid and component id.

    :param d: datalist entry; must contain the key ``"label"``
    :param centroid_key: key that receives the ``(int, int)`` centroid
    :param mask_value_key: key that receives the component id
    :param min_area: components with a smaller area are ignored
    :return: list of expanded entries (empty when no component qualifies)
    """
    label_image = LoadImage(image_only=True, dtype=np.uint8)(d["label"])
    components = cv2.connectedComponentsWithStats(label_image, 4, cv2.CV_32S)[1]

    entries = []
    for region in regionprops(components):
        if region.area < min_area:
            logger.debug(f"++++ Ignored label with smaller area => ( {region.area} < {min_area})")
            continue

        # Centroids are floats; floor them to integer pixel coordinates.
        first, second = (int(math.floor(coord)) for coord in region.centroid)

        entry = copy.deepcopy(d)
        entry[centroid_key] = (first, second)
        entry[mask_value_key] = region.label
        entries.append(entry)
    return entries
def train_nuclick(app):
    """Developer helper: start a NuClick training run on ``app`` with fixed options.

    :param app: application object exposing ``train(request=...)``
    """
    request = {
        "name": "train_01",
        "model": "nuclick",
        "max_epochs": 10,
        "dataset": "PersistentDataset",  # PersistentDataset, CacheDataset
        "train_batch_size": 128,
        "val_batch_size": 64,
        "multi_gpu": True,
        "val_split": 0.2,
        "dataset_source": "none",
        "dataset_limit": 0,
        "pretrained": False,
        "n_saved": 10,
    }
    app.train(request=request)