From a7e02f69248d79dd32a81366aaec016bce186359 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Thu, 5 May 2022 05:50:08 -0700 Subject: [PATCH 1/7] Draft Training workflow for NuClick Signed-off-by: Sachidanand Alle --- sample-apps/pathology/lib/configs/nuclick.py | 18 +- sample-apps/pathology/lib/infers/nuclick.py | 7 +- sample-apps/pathology/lib/nets/nuclick.py | 486 ++++++++++++++++++ .../pathology/lib/trainers/__init__.py | 1 + sample-apps/pathology/lib/trainers/nuclick.py | 357 +++++++++++++ sample-apps/pathology/lib/utils.py | 56 +- sample-apps/pathology/main.py | 42 +- 7 files changed, 944 insertions(+), 23 deletions(-) create mode 100644 sample-apps/pathology/lib/nets/nuclick.py create mode 100644 sample-apps/pathology/lib/trainers/nuclick.py diff --git a/sample-apps/pathology/lib/configs/nuclick.py b/sample-apps/pathology/lib/configs/nuclick.py index 32db53459..fe61e3149 100644 --- a/sample-apps/pathology/lib/configs/nuclick.py +++ b/sample-apps/pathology/lib/configs/nuclick.py @@ -59,4 +59,20 @@ def infer(self) -> Union[InferTask, Dict[str, InferTask]]: return task def trainer(self) -> Optional[TrainTask]: - return None + output_dir = os.path.join(self.model_dir, self.name) + task: TrainTask = lib.trainers.NuClick( + model_dir=output_dir, + network=self.network, + load_path=self.path[0], + publish_path=self.path[1], + labels=self.labels, + description="Train Nuclei DeepEdit Model", + config={ + "max_epochs": 10, + "train_batch_size": 64, + "dataset_max_region": (10240, 10240), + "dataset_limit": 0, + "dataset_randomize": True, + }, + ) + return task diff --git a/sample-apps/pathology/lib/infers/nuclick.py b/sample-apps/pathology/lib/infers/nuclick.py index d42156a6a..227c3f28d 100644 --- a/sample-apps/pathology/lib/infers/nuclick.py +++ b/sample-apps/pathology/lib/infers/nuclick.py @@ -130,8 +130,7 @@ def __call__(self, data): return d @staticmethod - def get_clickmap_boundingbox(cx, cy, m, n): - bb = 128 + def get_clickmap_boundingbox(cx, cy, m, n, bb=128): click_map = np.zeros((m, n), dtype=np.uint8) # Removing points out of image dimension (these points may have been clicked unwanted) @@ -162,9 +161,7 @@ def get_clickmap_boundingbox(cx, cy, m, n): return click_map, bounding_boxes @staticmethod - def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n): - bb = 128 - + def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n, bb=128): # total = number of clicks total = len(bounding_boxes) img = np.array([img]) # img.shape=(1,3,m,n) diff --git a/sample-apps/pathology/lib/nets/nuclick.py b/sample-apps/pathology/lib/nets/nuclick.py new file mode 100644 index 000000000..b8d6ebe03 --- /dev/null +++ b/sample-apps/pathology/lib/nets/nuclick.py @@ -0,0 +1,486 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +bn_axis = 1 + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + 
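+        # nn.MaxPool2d(2) halves the spatial size; DoubleConv then maps in_channels -> out_channels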
self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + + +# --------------------New layers development in progress----------------- +"""(convolution => [BN] => ReLU/sigmoid)""" +"""No regularizer""" + + +class Conv_Bn_Relu(nn.Module): + def __init__( + self, + in_channels, + out_channels=32, + kernelSize=(3, 3), + strds=(1, 1), + useBias=False, + dilatationRate=(1, 1), + actv="relu", + doBatchNorm=True, + ): + + super().__init__() + if isinstance(kernelSize, int): + kernelSize = (kernelSize, kernelSize) + if isinstance(strds, int): + strds = (strds, strds) + + self.conv_bn_relu = self.get_block( + in_channels, out_channels, kernelSize, strds, useBias, dilatationRate, actv, doBatchNorm + ) + + def forward(self, input): + return self.conv_bn_relu(input) + + def get_block(self, in_channels, out_channels, kernelSize, strds, useBias, dilatationRate, actv, doBatchNorm): + + layers = [] + + conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernelSize, + stride=strds, + dilation=dilatationRate, + bias=useBias, + padding="same", + padding_mode="zeros", + ) + + if actv == "selu": + # (Can't find 'lecun_normal' equivalent in PyTorch) + torch.nn.init.xavier_normal_(conv1.weight) + else: + torch.nn.init.xavier_uniform_(conv1.weight) + + layers.append(conv1) + + if actv != "selu" and doBatchNorm: + layers.append(nn.BatchNorm2d(num_features=out_channels, eps=1.001e-5)) + + if actv == "relu": + layers.append(nn.ReLU()) + elif actv == "sigmoid": + layers.append(nn.Sigmoid()) + elif actv == "selu": + layers.append(nn.SELU()) + + block = nn.Sequential(*layers) + return block + + +"""Multiscale Conv Block""" + + +class Multiscale_Conv_Block(nn.Module): + def __init__( + self, + in_channels, + kernelSizes, + dilatationRates, + out_channels=32, + strds=(1, 1), + actv="relu", + useBias=False, + isDense=True, + ): + + super().__init__() + + # Initialise conv blocks + if isDense: + self.conv_block_0 = Conv_Bn_Relu( + in_channels=in_channels, + out_channels=4 * out_channels, + kernelSize=1, + strds=strds, + actv=actv, + useBias=useBias, + ) + self.conv_block_5 = Conv_Bn_Relu( + in_channels=in_channels, + 
out_channels=out_channels, + kernelSize=3, + strds=strds, + actv=actv, + useBias=useBias, + ) + else: + self.conv_block_0 = None + self.conv_block_5 = None + + self.conv_block_1 = Conv_Bn_Relu( + in_channels=in_channels, + out_channels=out_channels, + kernelSize=kernelSizes[0], + strds=strds, + actv=actv, + useBias=useBias, + dilatationRate=(dilatationRates[0], dilatationRates[0]), + ) + + self.conv_block_2 = Conv_Bn_Relu( + in_channels=in_channels, + out_channels=out_channels, + kernelSize=kernelSizes[1], + strds=strds, + actv=actv, + useBias=useBias, + dilatationRate=(dilatationRates[1], dilatationRates[1]), + ) + + self.conv_block_3 = Conv_Bn_Relu( + in_channels=in_channels, + out_channels=out_channels, + kernelSize=kernelSizes[2], + strds=strds, + actv=actv, + useBias=useBias, + dilatationRate=(dilatationRates[2], dilatationRates[2]), + ) + + self.conv_block_4 = Conv_Bn_Relu( + in_channels=in_channels, + out_channels=out_channels, + kernelSize=kernelSizes[3], + strds=strds, + actv=actv, + useBias=useBias, + dilatationRate=(dilatationRates[3], dilatationRates[3]), + ) + + def forward(self, input_map): + # If isDense == True + if self.conv_block_0 is not None: + conv0 = self.conv_block_0(input_map) + else: + conv0 = input_map + + conv1 = self.conv_block_1(conv0) + conv2 = self.conv_block_2(conv0) + conv3 = self.conv_block_3(conv0) + conv4 = self.conv_block_4(conv0) + + # (Not sure about bn_axis) + output_map = torch.cat([conv1, conv2, conv3, conv4], dim=bn_axis) + + # If isDense == True + if self.conv_block_5 is not None: + output_map = self.conv_block_5(output_map) + # (Not sure about bn_axis) + output_map = torch.cat([input_map, output_map], dim=bn_axis) + + return output_map + + +"""Residual_Conv""" + + +class Residual_Conv(nn.Module): + def __init__( + self, + in_channels, + out_channels=32, + kernelSize=(3, 3), + strds=(1, 1), + actv="relu", + useBias=False, + dilatationRate=(1, 1), + ): + super().__init__() + + if actv == "selu": + self.conv_block_1 = Conv_Bn_Relu( + in_channels, + out_channels, + kernelSize=kernelSize, + strds=strds, + actv="None", + useBias=useBias, + dilatationRate=dilatationRate, + doBatchNorm=False, + ) + self.conv_block_2 = Conv_Bn_Relu( + in_channels, + out_channels, + kernelSize=kernelSize, + strds=strds, + actv="None", + useBias=useBias, + dilatationRate=dilatationRate, + doBatchNorm=False, + ) + self.activation = nn.SELU() + else: + self.conv_block_1 = Conv_Bn_Relu( + in_channels, + out_channels, + kernelSize=kernelSize, + strds=strds, + actv="None", + useBias=useBias, + dilatationRate=dilatationRate, + doBatchNorm=True, + ) + self.conv_block_2 = Conv_Bn_Relu( + out_channels, + out_channels, + kernelSize=kernelSize, + strds=strds, + actv="None", + useBias=useBias, + dilatationRate=dilatationRate, + doBatchNorm=True, + ) + + if actv == "relu": + self.activation = nn.ReLU() + + if actv == "sigmoid": + self.activation = nn.Sigmoid() + + def forward(self, input): + conv1 = self.conv_block_1(input) + conv2 = self.conv_block_2(conv1) + + out = torch.add(conv1, conv2) + out = self.activation(out) + return out + + +class NuClick_NN(nn.Module): + def __init__(self, n_channels, n_classes): + super().__init__() + self.net_name = "NuClick" + + self.n_channels = n_channels + self.n_classes = n_classes + + # -------------Conv_Bn_Relu blocks------------ + self.conv_block_1 = nn.Sequential( + Conv_Bn_Relu(in_channels=self.n_channels, out_channels=64, kernelSize=7), + Conv_Bn_Relu(in_channels=64, out_channels=32, kernelSize=5), + Conv_Bn_Relu(in_channels=32, 
out_channels=32, kernelSize=3), + ) + + self.conv_block_2 = nn.Sequential( + Conv_Bn_Relu(in_channels=64, out_channels=64), + Conv_Bn_Relu(in_channels=64, out_channels=32), + Conv_Bn_Relu(in_channels=32, out_channels=32), + ) + + self.conv_block_3 = Conv_Bn_Relu( + in_channels=32, out_channels=self.n_classes, kernelSize=(1, 1), actv=None, useBias=True, doBatchNorm=False + ) + + # -------------Residual_Conv blocks------------ + self.residual_block_1 = nn.Sequential( + Residual_Conv(in_channels=32, out_channels=64), Residual_Conv(in_channels=64, out_channels=64) + ) + + self.residual_block_2 = Residual_Conv(in_channels=64, out_channels=128) + + self.residual_block_3 = Residual_Conv(in_channels=128, out_channels=128) + + self.residual_block_4 = nn.Sequential( + Residual_Conv(in_channels=128, out_channels=256), + Residual_Conv(in_channels=256, out_channels=256), + Residual_Conv(in_channels=256, out_channels=256), + ) + + self.residual_block_5 = nn.Sequential( + Residual_Conv(in_channels=256, out_channels=512), + Residual_Conv(in_channels=512, out_channels=512), + Residual_Conv(in_channels=512, out_channels=512), + ) + + self.residual_block_6 = nn.Sequential( + Residual_Conv(in_channels=512, out_channels=1024), Residual_Conv(in_channels=1024, out_channels=1024) + ) + + self.residual_block_7 = nn.Sequential( + Residual_Conv(in_channels=1024, out_channels=512), Residual_Conv(in_channels=512, out_channels=256) + ) + + self.residual_block_8 = Residual_Conv(in_channels=512, out_channels=256) + + self.residual_block_9 = Residual_Conv(in_channels=256, out_channels=256) + + self.residual_block_10 = nn.Sequential( + Residual_Conv(in_channels=256, out_channels=128), Residual_Conv(in_channels=128, out_channels=128) + ) + + self.residual_block_11 = Residual_Conv(in_channels=128, out_channels=64) + + self.residual_block_12 = Residual_Conv(in_channels=64, out_channels=64) + + # -------------Multiscale_Conv_Block blocks------------ + self.multiscale_block_1 = Multiscale_Conv_Block( + in_channels=128, out_channels=32, kernelSizes=[3, 3, 5, 5], dilatationRates=[1, 3, 3, 6], isDense=False + ) + + self.multiscale_block_2 = Multiscale_Conv_Block( + in_channels=256, out_channels=64, kernelSizes=[3, 3, 5, 5], dilatationRates=[1, 3, 2, 3], isDense=False + ) + + self.multiscale_block_3 = Multiscale_Conv_Block( + in_channels=64, out_channels=16, kernelSizes=[3, 3, 5, 7], dilatationRates=[1, 3, 2, 6], isDense=False + ) + + # -------------MaxPool2d blocks------------ + self.pool_block_1 = nn.MaxPool2d(kernel_size=(2, 2)) + self.pool_block_2 = nn.MaxPool2d(kernel_size=(2, 2)) + self.pool_block_3 = nn.MaxPool2d(kernel_size=(2, 2)) + self.pool_block_4 = nn.MaxPool2d(kernel_size=(2, 2)) + self.pool_block_5 = nn.MaxPool2d(kernel_size=(2, 2)) + + # -------------ConvTranspose2d blocks------------ + self.conv_transpose_1 = nn.ConvTranspose2d( + in_channels=1024, + out_channels=512, + kernel_size=2, + stride=(2, 2), + ) + + self.conv_transpose_2 = nn.ConvTranspose2d( + in_channels=256, + out_channels=256, + kernel_size=2, + stride=(2, 2), + ) + + self.conv_transpose_3 = nn.ConvTranspose2d( + in_channels=256, + out_channels=128, + kernel_size=2, + stride=(2, 2), + ) + + self.conv_transpose_4 = nn.ConvTranspose2d( + in_channels=128, + out_channels=64, + kernel_size=2, + stride=(2, 2), + ) + + self.conv_transpose_5 = nn.ConvTranspose2d( + in_channels=64, + out_channels=32, + kernel_size=2, + stride=(2, 2), + ) + + def forward(self, input): + conv1 = self.conv_block_1(input) # conv1: 32 channels + pool1 = self.pool_block_1(conv1) # 
poo1: 32 channels + + conv2 = self.residual_block_1(pool1) # conv2: 64 channels + pool2 = self.pool_block_2(conv2) # pool2: 64 channels + + conv3 = self.residual_block_2(pool2) # conv3: 128 channels + conv3 = self.multiscale_block_1(conv3) # conv3: 128 channels + conv3 = self.residual_block_3(conv3) # conv3: 128 channels + pool3 = self.pool_block_3(conv3) # pool3: 128 channels + + conv4 = self.residual_block_4(pool3) # conv4: 256 channels + pool4 = self.pool_block_4(conv4) # pool4: 512 channels + + conv5 = self.residual_block_5(pool4) # conv5: 512 channels + pool5 = self.pool_block_5(conv5) # pool5: 512 channels + + conv51 = self.residual_block_6(pool5) # conv51: 1024 channels + + up61 = torch.cat([self.conv_transpose_1(conv51), conv5], dim=1) # up61: 1024 channels + conv61 = self.residual_block_7(up61) # conv61: 256 channels + + up6 = torch.cat([self.conv_transpose_2(conv61), conv4], dim=1) # up6: 512 channels + conv6 = self.residual_block_8(up6) # conv6: 256 channels + conv6 = self.multiscale_block_2(conv6) # conv6: 256 channels + conv6 = self.residual_block_9(conv6) # conv6: 256 channels + + up7 = torch.cat([self.conv_transpose_3(conv6), conv3], dim=1) # up7: 256 channels + conv7 = self.residual_block_10(up7) # conv7: 128 channels + + up8 = torch.cat([self.conv_transpose_4(conv7), conv2], dim=1) # up8: 128 channels + conv8 = self.residual_block_11(up8) # conv8: 64 channels + conv8 = self.multiscale_block_3(conv8) # conv8: 64 channels + conv8 = self.residual_block_12(conv8) # conv8: 64 channels + + up9 = torch.cat([self.conv_transpose_5(conv8), conv1], dim=1) # up9: 64 channels + conv9 = self.conv_block_2(up9) # conv9: 32 channels + + conv10 = self.conv_block_3(conv9) # conv10: out_channels + + return conv10 diff --git a/sample-apps/pathology/lib/trainers/__init__.py b/sample-apps/pathology/lib/trainers/__init__.py index 6463e5137..438bef6d5 100644 --- a/sample-apps/pathology/lib/trainers/__init__.py +++ b/sample-apps/pathology/lib/trainers/__init__.py @@ -10,4 +10,5 @@ # limitations under the License. from .deepedit_nuclei import DeepEditNuclei +from .nuclick import NuClick from .segmentation_nuclei import SegmentationNuclei diff --git a/sample-apps/pathology/lib/trainers/nuclick.py b/sample-apps/pathology/lib/trainers/nuclick.py new file mode 100644 index 000000000..1ad2ac931 --- /dev/null +++ b/sample-apps/pathology/lib/trainers/nuclick.py @@ -0,0 +1,357 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
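+# Training task for the NuClick model: with dataset_source="nuclick" each item is a
+# single nucleus centroid; ExtractPatchd crops a 128x128 patch around it and appends
+# inclusion/exclusion point-guidance maps (PointGuidingSignal) as extra input channels.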
+import logging +import os + +import cv2 +import numpy as np +import skimage +import torch +from ignite.metrics import Accuracy +from lib.handlers import TensorBoardImageHandler +from lib.transforms import FilterImaged +from lib.utils import split_dataset +from monai.apps.deepgrow.transforms import AddGuidanceSignald, AddRandomGuidanced, FindDiscrepancyRegionsd +from monai.handlers import from_engine +from monai.inferers import SimpleInferer +from monai.losses import DiceLoss +from monai.transforms import ( + Activationsd, + AddChanneld, + AsChannelFirstd, + AsDiscreted, + EnsureTyped, + LoadImaged, + RandRotate90d, + ScaleIntensityRangeD, + ToNumpyd, + TorchVisiond, + ToTensord, + Transform, +) + +from monailabel.interfaces.datastore import Datastore +from monailabel.tasks.train.basic_train import BasicTrainTask, Context + +logger = logging.getLogger(__name__) + + +class NuClick(BasicTrainTask): + def __init__( + self, + model_dir, + network, + labels, + tile_size=(256, 256), + roi_size=(128, 128), + description="Pathology NuClick Segmentation", + **kwargs, + ): + self._network = network + self.labels = labels + self.tile_size = tile_size + self.roi_size = roi_size + super().__init__(model_dir, description, **kwargs) + + def network(self, context: Context): + return self._network + + def optimizer(self, context: Context): + return torch.optim.Adam(context.network.parameters(), 0.0001) + + def loss_function(self, context: Context): + return DiceLoss(sigmoid=True, squared_pred=True) + + def pre_process(self, request, datastore: Datastore): + self.cleanup(request) + + cache_dir = os.path.join(self.get_cache_dir(request), "train_ds") + source = request.get("dataset_source") + max_region = request.get("dataset_max_region", (10240, 10240)) + max_region = (max_region, max_region) if isinstance(max_region, int) else max_region[:2] + + return split_dataset( + datastore=datastore, + cache_dir=cache_dir, + source=source, + groups=self.labels, + tile_size=self.tile_size, + max_region=max_region, + limit=request.get("dataset_limit", 0), + randomize=request.get("dataset_randomize", True), + ) + + def get_click_transforms(self, context: Context): + return [ + Activationsd(keys="pred", sigmoid=True), + ToNumpyd(keys=("image", "label", "pred")), + FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"), + AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"), + AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3), + ToTensord(keys=("image", "label")), + ] + + def train_pre_transforms(self, context: Context): + return [ + LoadImaged(keys=("image", "label"), dtype=np.uint8), + FilterImaged(keys="image", min_size=5), + AsChannelFirstd(keys="image"), + AddChanneld(keys="label"), + ToTensord(keys="image"), + TorchVisiond( + keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04 + ), + ToNumpyd(keys="image"), + RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(0, 1)), + ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0), + ExtractPatchd(), + EnsureTyped(keys=("image", "label")), + ] + + def train_post_transforms(self, context: Context): + return [ + Activationsd(keys="pred", sigmoid=True), + AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5), + ] + + def train_key_metric(self, context: Context): + return {"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))} + + def val_key_metric(self, context: Context): 
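+        # Same Accuracy metric as training, reported under the "val_acc" key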
+ return {"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))} + + def val_inferer(self, context: Context): + return SimpleInferer() + + def train_handlers(self, context: Context): + handlers = super().train_handlers(context) + if context.local_rank == 0: + handlers.append(TensorBoardImageHandler(log_dir=context.events_dir, batch_limit=4)) + return handlers + + +class ExtractPatchd(Transform): + def __init__(self, image="image", label="label", centroid="centroid"): + self.image = image + self.label = label + self.centroid = centroid + + self.patch_size = 128 + self.perturb = "distance" + self.drop_rate = 0.5 + self.jitter_range = 3 + + def __call__(self, data): + d = dict(data) + + img = d[self.image] + mask = d[self.label][0] + y, x = d[self.centroid] + m, n = mask.shape[:2] + + x_start = int(max(x - self.patch_size / 2, 0)) + y_start = int(max(y - self.patch_size / 2, 0)) + x_end = x_start + self.patch_size + y_end = y_start + self.patch_size + if x_end > n: + x_end = n + x_start = n - self.patch_size + if y_end > m: + y_end = m + y_start = m - self.patch_size + + mask_val = mask[y, x] + + mask_patch = mask[y_start:y_end, x_start:x_end] + img_patch = img[:, y_start:y_end, x_start:x_end] + + mask_patch_in = (mask_patch == mask_val).astype(np.uint8) + others_patch_in = (1 - mask_patch_in) * mask_patch + others_patch_in = self.mask_relabeling(others_patch_in, size_limit=5).astype(np.uint8) + + pad_size = (self.patch_size, self.patch_size) + img_patch = self.pad_to_shape(img_patch, pad_size, False) + mask_patch_in = self.pad_to_shape(mask_patch_in, pad_size, True) + others_patch_in = self.pad_to_shape(others_patch_in, pad_size, True) + + # create the guiding signals + signal_gen = PointGuidingSignal(mask_patch_in, others_patch_in, perturb=self.perturb) + inc_signal = signal_gen.inclusion_map() + exc_signal = signal_gen.exclusion_map(random_drop=self.drop_rate, random_jitter=self.jitter_range) + + image_patch = np.concatenate((img_patch, inc_signal[np.newaxis, ...], exc_signal[np.newaxis, ...]), axis=0) + d[self.image] = image_patch + d[self.label] = mask_patch_in[np.newaxis] + return d + + @staticmethod + def pad_to_shape(img, shape, is_mask): + img_shape = img.shape[-2:] + shape_diff = np.array(shape) - np.array(img_shape) + if is_mask: + img_padded = np.pad(img, [(0, shape_diff[0]), (0, shape_diff[1])], mode="constant", constant_values=0) + else: + img_padded = np.pad( + img, [(0, 0), (0, shape_diff[0]), (0, shape_diff[1])], mode="constant", constant_values=0 + ) + return img_padded + + @staticmethod + def mask_relabeling(mask, size_limit=5): + out_mask = np.zeros_like(mask, dtype=np.uint16) + unique_labels = np.unique(mask) + if unique_labels[0] == 0: + unique_labels = np.delete(unique_labels, 0) + + i = 1 + for l in unique_labels: + m = skimage.measure.label(mask == l, connectivity=1) + stats = skimage.measure.regionprops(m) + for stat in stats: + if stat.area > size_limit: + out_mask[stat.coords[:, 0], stat.coords[:, 1]] = i + i += 1 + return out_mask + + +def adaptive_distance_thresholding(mask): + """Refining the input mask using adaptive distance thresholding. + + Distance map of the input mask is generated and the an adaptive + (random) threshold based on the distance map is calculated to + generate a new mask from distance map based on it. 
+ + Inputs: + mask (::np.ndarray::): Should be a 2D binary numpy array (uint8) + Outputs: + new_mask (::np.ndarray::): the refined mask + dist (::np.ndarray::): the distance map + """ + dist = cv2.distanceTransform(mask, cv2.DIST_L2, 0) + tempMean = np.mean(dist[dist > 0]) + tempStd = np.std(dist[dist > 0]) + tempTol = tempStd / 2 + low_thresh = np.max([tempMean - tempTol, 0]) + high_thresh = np.min([tempMean + tempTol, np.max(dist) - tempTol]) + if low_thresh >= high_thresh: + thresh = tempMean + else: + thresh = np.random.uniform(low_thresh, high_thresh) + new_mask = dist > thresh + if np.all(new_mask == np.zeros_like(new_mask)): + new_mask = dist > tempMean + return new_mask, dist + + +class GuidingSignal: + """A generic class for defining guiding signal generators. + + This class include some special methods that inclusion and exclusion guiding signals + for different application can be created based on. + """ + + def __init__(self, mask: np.ndarray, others: np.ndarray, kernel_size: int = 0) -> None: + self.mask = self.mask_validator(mask > 0.5) + self.kernel_size = kernel_size + if kernel_size: + self.current_mask = self.mask_preprocess(self.mask, kernel_size=self.kernel_size) + else: + self.current_mask = self.mask_validator(mask > 0.5) + self.others = others + + @staticmethod + def mask_preprocess(mask, kernel_size=3): + kernel = np.ones((kernel_size, kernel_size), np.uint8) + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + if np.all(mask == np.zeros_like(mask)): + logging.warning( + f"The kernel_size (radius) of {kernel_size} may be too high, consider checking " + "the intermediate output for the sanity of generated masks." + ) + return mask + + @staticmethod + def mask_validator(mask): + """Validate the input mask be np.uint8 and 2D""" + assert len(mask.shape) == 2, "Mask must be a 2D array (NxM)" + if not issubclass(type(mask[0, 0]), np.integer): + mask = np.uint8(mask) + return mask + + def inclusion_map(self): + """A function to generate inclusion gioding signal""" + raise NotImplementedError + + def exclusion_map(self): + """A function to generate exclusion gioding signal""" + raise NotImplementedError + + +class PointGuidingSignal(GuidingSignal): + def __init__(self, mask: np.ndarray, others: np.ndarray, perturb: str = "None", **kwargs) -> None: + super().__init__(mask, others, **kwargs) + if perturb.lower() not in {"none", "distance", "inside"}: + raise ValueError( + f'Invalid running perturb type of: {perturb}. Perturn type should be `"None"`, `"inside"`, or `"distance"`.' 
+ ) + self.perturb = perturb.lower() + + def inclusion_map(self): + if self.perturb is None: # if there is no purturbation + indices = np.argwhere(self.current_mask == 1) # + centroid = np.mean(indices, axis=0) + pointMask = np.zeros_like(self.current_mask) + pointMask[int(centroid[0]), int(centroid[1]), 0] = 1 + return pointMask, self.current_mask + elif self.perturb == "distance" and np.any(self.current_mask > 0): + new_mask, _ = adaptive_distance_thresholding(self.current_mask) + else: # if self.perturb=='inside': + new_mask = self.current_mask.copy() + + # Creating the point map + pointMask = np.zeros_like(self.current_mask) + indices = np.argwhere(new_mask == 1) + if len(indices) > 0: + rndIdx = np.random.randint(0, len(indices)) + rndX = indices[rndIdx, 1] + rndY = indices[rndIdx, 0] + pointMask[rndY, rndX] = 1 + + return pointMask + + def exclusion_map(self, random_drop=0.0, random_jitter=0): + _, _, _, centroids = cv2.connectedComponentsWithStats(self.others, 4, cv2.CV_32S) + + centroids = centroids[1:, :] # removing the first centroid, it's background + if random_jitter: + centroids = self.jitterClicks(self.current_mask.shape, centroids, jitter_range=random_jitter) + if random_drop: # randomly dropping some of the points + drop_prob = np.random.uniform(0, random_drop) + num_select = int((1 - drop_prob) * centroids.shape[0]) + select_indices = np.random.choice(centroids.shape[0], size=num_select, replace=False) + centroids = centroids[select_indices, :] + centroids = np.int64(np.floor(centroids)) + + # create the point map + pointMask = np.zeros_like(self.others) + pointMask[centroids[:, 1], centroids[:, 0]] = 1 + + return pointMask + + @staticmethod + def jitterClicks(shape, centroids, jitter_range=3): + """Randomly jitter the centroid points + Points should be an array in (x, y) format while shape is (H, W) of the point map + """ + centroids += np.random.uniform(low=-jitter_range, high=jitter_range, size=centroids.shape) + centroids[:, 0] = np.clip(centroids[:, 0], 0, shape[1] - 1) + centroids[:, 1] = np.clip(centroids[:, 1], 0, shape[0] - 1) + return centroids diff --git a/sample-apps/pathology/lib/utils.py b/sample-apps/pathology/lib/utils.py index 3b898730a..b1b66d914 100644 --- a/sample-apps/pathology/lib/utils.py +++ b/sample-apps/pathology/lib/utils.py @@ -1,25 +1,31 @@ +import copy import logging import os import random import shutil -import xml.etree.ElementTree as ET +import xml.etree.ElementTree from io import BytesIO from math import ceil import cv2 import numpy as np import openslide +from monai.transforms import LoadImage from PIL import Image +from skimage.measure import regionprops from tqdm import tqdm from monailabel.datastore.dsa import DSADatastore from monailabel.datastore.local import LocalDatastore +from monailabel.interfaces.datastore import Datastore from monailabel.utils.others.generic import get_basename logger = logging.getLogger(__name__) -def split_dataset(datastore, cache_dir, source, groups, tile_size, max_region=(10240, 10240), limit=0, randomize=True): +def split_dataset( + datastore: Datastore, cache_dir, source, groups, tile_size, max_region=(10240, 10240), limit=0, randomize=True +): ds = datastore.datalist() shutil.rmtree(cache_dir, ignore_errors=True) @@ -28,6 +34,15 @@ def split_dataset(datastore, cache_dir, source, groups, tile_size, max_region=(1 if image is not None and len(image.shape) > 3: logger.info(f"PANNuke (For Developer Mode only):: Split data; groups: {groups}") ds = split_pannuke_dataset(ds[0]["image"], ds[0]["label"], 
cache_dir, groups) + elif source == "nuclick": + logger.info("Split data based on each nuclei") + ds_new = [] + for d in tqdm(ds): + ds_new.extend(split_nuclei_dataset(d)) + if 0 < limit < len(ds_new): + ds_new = ds_new[:limit] + break + ds = ds_new else: logger.info(f"Split data based on tile size: {tile_size}; groups: {groups}") ds_new = [] @@ -156,7 +171,7 @@ def split_local_dataset(datastore, d, output_dir, groups, tile_size, max_region= points = [] polygons = {g: [] for g in groups} - annotations_xml = ET.parse(d["label"]).getroot() + annotations_xml = xml.etree.ElementTree.parse(d["label"]).getroot() for annotation in annotations_xml.iter("Annotation"): g = annotation.get("PartOfGroup") g = g if g else "None" @@ -183,6 +198,22 @@ def split_local_dataset(datastore, d, output_dir, groups, tile_size, max_region= return dataset_json +def split_nuclei_dataset(d, centroid_key="centroid"): + dataset_json = [] + mask = LoadImage(image_only=True, dtype=np.uint8)(d["label"]) + stats = regionprops(mask) + for stat in stats: + y, x = stat.centroid + y = int(np.floor(y)) + x = int(np.floor(x)) + + if mask[y, x]: + item = copy.deepcopy(d) + item[centroid_key] = (y, x) + dataset_json.append(item) + return dataset_json + + def _group_item(groups, d, output_dir): groups = groups if groups else dict() groups = [groups] if isinstance(groups, str) else groups @@ -310,7 +341,7 @@ def main_dsa(): datastore = DSADatastore(api_url, folder, api_key, annotation_groups, asset_store_path) print(json.dumps(datastore.datalist(), indent=2)) - split_dataset(datastore, "/localhome/sachi/Downloads/dsa/mostly_tumor", None, annotation_groups, (256, 256)) + split_dataset(datastore, "/localhome/sachi/Downloads/dsa/mostly_tumor", "", annotation_groups, (256, 256)) def main_nuke(): @@ -341,9 +372,22 @@ def main_local(): datastore = LocalDatastore("C:\\Projects\\Pathology\\Test", extensions=("*.svs", "*.xml")) print(json.dumps(datastore.datalist(), indent=2)) - split_dataset(datastore, "C:\\Projects\\Pathology\\TestF", None, annotation_groups, (256, 256)) + split_dataset(datastore, "C:\\Projects\\Pathology\\TestF", "", annotation_groups, (256, 256)) # print(json.dumps(ds, indent=2)) +def main_nuclei(): + from monailabel.datastore.local import LocalDatastore + + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + datastore = LocalDatastore("/localhome/sachi/Datasets/NuClick", extensions=("*.png", "*.npy")) + split_dataset(datastore, None, "nuclick", None, None, limit=0) + + if __name__ == "__main__": - main_local() + main_nuclei() diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py index 0688b3fe5..7716557b6 100644 --- a/sample-apps/pathology/main.py +++ b/sample-apps/pathology/main.py @@ -160,14 +160,16 @@ def main(): datefmt="%Y-%m-%d %H:%M:%S", ) - run_train = False + run_train = True home = str(Path.home()) if run_train: + studies = f"{home}/Datasets/NuClick" # studies = f"{home}/Data/Pathology/PanNuke" - studies = "http://0.0.0.0:8080/api/v1" + # studies = "http://0.0.0.0:8080/api/v1" else: - # studies = f"{home}/Data/Pathology/Test" - studies = "C:\\Projects\\Pathology\\Test" + # studies = f"{home}/Datasets/Pathology/" + # studies = "C:\\Projects\\Pathology\\Test" + studies = f"{home}/Datasets/NuClick" parser = argparse.ArgumentParser() parser.add_argument("-s", "--studies", default=studies) @@ -176,12 +178,30 @@ def main(): app_dir = os.path.dirname(__file__) 
studies = args.studies - app = MyApp(app_dir, studies, {}) + app = MyApp(app_dir, studies, {"roi_size": "[1024,1024]", "preload": "false"}) if run_train: - train(app) + train_nuclick(app) else: - infer_nuclick(app) - # infer_wsi(app) + # infer_nuclick(app) + infer_wsi(app) + + +def train_nuclick(app): + model = "nuclick" + app.train( + request={ + "name": "train_01", + "model": model, + "max_epochs": 10, + "dataset": "CacheDataset", # PersistentDataset, CacheDataset + "train_batch_size": 64, + "val_batch_size": 32, + "multi_gpu": False, + "val_split": 0.2, + "dataset_source": "nuclick", + "dataset_limit": 10, + }, + ) def train(app): @@ -241,10 +261,10 @@ def infer_wsi(app): home = str(Path.home()) - root_dir = f"{home}/Data/Pathology" + root_dir = f"{home}/Datasets/" image = "TCGA-02-0010-01Z-00-DX4.07de2e55-a8fe-40ee-9e98-bcb78050b9f7" - output = "dsa" + output = "asap" # slide = openslide.OpenSlide(f"{app.studies}/{image}.svs") # img = slide.read_region((7737, 20086), 0, (2048, 2048)).convert("RGB") @@ -259,7 +279,7 @@ def infer_wsi(app): "level": 0, "location": [0, 0], "size": [0, 0], - "tile_size": [2048, 2048], + "tile_size": [1024, 1024], "min_poly_area": 40, "gpus": "all", "multi_gpu": True, From 692a1efeffdd4f3a6982fda65d0fd384d682cf46 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Fri, 6 May 2022 13:42:16 -0700 Subject: [PATCH 2/7] Sync up changes for nuclick training Signed-off-by: Sachidanand Alle --- monailabel/tasks/train/basic_train.py | 2 +- sample-apps/pathology/lib/handlers.py | 4 +- sample-apps/pathology/lib/trainers/nuclick.py | 300 ++++++++++-------- sample-apps/pathology/main.py | 4 +- 4 files changed, 169 insertions(+), 141 deletions(-) diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py index 8dd2d205e..1ea36d4da 100644 --- a/monailabel/tasks/train/basic_train.py +++ b/monailabel/tasks/train/basic_train.py @@ -340,7 +340,7 @@ def config(self): @staticmethod def _validate_transforms(transforms, step="Training", name="pre"): - if not transforms or isinstance(transforms, Compose): + if not transforms or isinstance(transforms, Compose) or callable(transforms): return transforms if isinstance(transforms, list): return Compose(transforms) diff --git a/sample-apps/pathology/lib/handlers.py b/sample-apps/pathology/lib/handlers.py index 33f19154b..ed32a7e2d 100644 --- a/sample-apps/pathology/lib/handlers.py +++ b/sample-apps/pathology/lib/handlers.py @@ -141,7 +141,7 @@ def write_images(self, batch_data, output_data, epoch): label[y == region] = region self.logger.info( - "{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {})".format( + "{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {}); (pos-nz: {}); (neg-nz: {})".format( bidx, region, image.shape, @@ -149,6 +149,8 @@ def write_images(self, batch_data, output_data, epoch): np.count_nonzero(label), y_pred.shape, np.count_nonzero(y_pred[region]), + np.count_nonzero(image[-2]), + np.count_nonzero(image[-1]), ) ) diff --git a/sample-apps/pathology/lib/trainers/nuclick.py b/sample-apps/pathology/lib/trainers/nuclick.py index 1ad2ac931..030ef8e3d 100644 --- a/sample-apps/pathology/lib/trainers/nuclick.py +++ b/sample-apps/pathology/lib/trainers/nuclick.py @@ -11,32 +11,18 @@ import logging import os +import albumentations as alb import cv2 import numpy as np import skimage import torch from ignite.metrics import Accuracy from lib.handlers import TensorBoardImageHandler -from lib.transforms import FilterImaged from lib.utils import split_dataset -from 
monai.apps.deepgrow.transforms import AddGuidanceSignald, AddRandomGuidanced, FindDiscrepancyRegionsd from monai.handlers import from_engine from monai.inferers import SimpleInferer from monai.losses import DiceLoss -from monai.transforms import ( - Activationsd, - AddChanneld, - AsChannelFirstd, - AsDiscreted, - EnsureTyped, - LoadImaged, - RandRotate90d, - ScaleIntensityRangeD, - ToNumpyd, - TorchVisiond, - ToTensord, - Transform, -) +from monai.transforms import Activationsd, AsDiscreted, LoadImaged, RandomizableTransform, Transform from monailabel.interfaces.datastore import Datastore from monailabel.tasks.train.basic_train import BasicTrainTask, Context @@ -89,31 +75,12 @@ def pre_process(self, request, datastore: Datastore): randomize=request.get("dataset_randomize", True), ) - def get_click_transforms(self, context: Context): - return [ - Activationsd(keys="pred", sigmoid=True), - ToNumpyd(keys=("image", "label", "pred")), - FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"), - AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"), - AddGuidanceSignald(image="image", guidance="guidance", number_intensity_ch=3), - ToTensord(keys=("image", "label")), - ] - def train_pre_transforms(self, context: Context): return [ LoadImaged(keys=("image", "label"), dtype=np.uint8), - FilterImaged(keys="image", min_size=5), - AsChannelFirstd(keys="image"), - AddChanneld(keys="label"), - ToTensord(keys="image"), - TorchVisiond( - keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04 - ), - ToNumpyd(keys="image"), - RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(0, 1)), - ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0), ExtractPatchd(), - EnsureTyped(keys=("image", "label")), + Augmentd(), + AddSignald(), ] def train_post_transforms(self, context: Context): @@ -139,22 +106,15 @@ def train_handlers(self, context: Context): class ExtractPatchd(Transform): - def __init__(self, image="image", label="label", centroid="centroid"): - self.image = image - self.label = label - self.centroid = centroid - + def __init__(self): self.patch_size = 128 - self.perturb = "distance" - self.drop_rate = 0.5 - self.jitter_range = 3 def __call__(self, data): d = dict(data) - img = d[self.image] - mask = d[self.label][0] - y, x = d[self.centroid] + img = d["image"] + mask = d["label"] + y, x = d["centroid"] m, n = mask.shape[:2] x_start = int(max(x - self.patch_size / 2, 0)) @@ -170,42 +130,25 @@ def __call__(self, data): mask_val = mask[y, x] - mask_patch = mask[y_start:y_end, x_start:x_end] - img_patch = img[:, y_start:y_end, x_start:x_end] + img_patch = img[y_start:y_end, x_start:x_end, :] + mask_p = mask[y_start:y_end, x_start:x_end] - mask_patch_in = (mask_patch == mask_val).astype(np.uint8) - others_patch_in = (1 - mask_patch_in) * mask_patch - others_patch_in = self.mask_relabeling(others_patch_in, size_limit=5).astype(np.uint8) + mask_patch = (mask_p == mask_val).astype(np.uint8) + others_patch = (1 - mask_patch) * mask_p + others_patch = self._mask_relabeling(others_patch, size_limit=5).astype(np.uint8) pad_size = (self.patch_size, self.patch_size) img_patch = self.pad_to_shape(img_patch, pad_size, False) - mask_patch_in = self.pad_to_shape(mask_patch_in, pad_size, True) - others_patch_in = self.pad_to_shape(others_patch_in, pad_size, True) + mask_patch = self.pad_to_shape(mask_patch, pad_size, True) + others_patch = self.pad_to_shape(others_patch, 
pad_size, True) - # create the guiding signals - signal_gen = PointGuidingSignal(mask_patch_in, others_patch_in, perturb=self.perturb) - inc_signal = signal_gen.inclusion_map() - exc_signal = signal_gen.exclusion_map(random_drop=self.drop_rate, random_jitter=self.jitter_range) - - image_patch = np.concatenate((img_patch, inc_signal[np.newaxis, ...], exc_signal[np.newaxis, ...]), axis=0) - d[self.image] = image_patch - d[self.label] = mask_patch_in[np.newaxis] + d["image"] = img_patch + d["label"] = mask_patch + d["others"] = others_patch return d @staticmethod - def pad_to_shape(img, shape, is_mask): - img_shape = img.shape[-2:] - shape_diff = np.array(shape) - np.array(img_shape) - if is_mask: - img_padded = np.pad(img, [(0, shape_diff[0]), (0, shape_diff[1])], mode="constant", constant_values=0) - else: - img_padded = np.pad( - img, [(0, 0), (0, shape_diff[0]), (0, shape_diff[1])], mode="constant", constant_values=0 - ) - return img_padded - - @staticmethod - def mask_relabeling(mask, size_limit=5): + def _mask_relabeling(mask, size_limit=5): out_mask = np.zeros_like(mask, dtype=np.uint16) unique_labels = np.unique(mask) if unique_labels[0] == 0: @@ -221,51 +164,123 @@ def mask_relabeling(mask, size_limit=5): i += 1 return out_mask + @staticmethod + def pad_to_shape(img, shape, is_mask): + img_shape = img.shape[:2] + s_diff = np.array(shape) - np.array(img_shape) + diff = [(0, s_diff[0]), (0, s_diff[1])] + if not is_mask: + diff.append((0, 0)) + + return np.pad( + img, + diff, + mode="constant", + constant_values=0, + ) + + +class Augmentd(RandomizableTransform): + def __init__(self): + self.train_augs = alb.Compose( + [ + alb.OneOf( + [ + alb.HueSaturationValue( + hue_shift_limit=10, + sat_shift_limit=(-30, 20), + val_shift_limit=0, + always_apply=False, + p=0.75, + ), # .8 + alb.RGBShift( + r_shift_limit=20, + g_shift_limit=20, + b_shift_limit=20, + p=0.75, + ), # .7 + ], + p=1.0, + ), + alb.OneOf( + [ + alb.GaussianBlur(blur_limit=(3, 5), p=0.5), + alb.Sharpen(alpha=(0.1, 0.3), lightness=(1.0, 1.0), p=0.5), + alb.ImageCompression(quality_lower=30, quality_upper=80, p=0.5), + ], + p=1.0, + ), + alb.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + alb.ShiftScaleRotate( + shift_limit=0.1, + scale_limit=0.2, + rotate_limit=180, + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=0.5, + ), + alb.Flip(p=0.5), + ], + additional_targets={"others": "mask"}, + p=0.5, + ) + + def __call__(self, data): + d = dict(data) + img = d["image"] + mask = d["label"] + others = d["others"] + + augmented_data = self.train_augs(image=img, mask=mask, others=others) + d["image"] = augmented_data["image"] + d["label"] = augmented_data["mask"] + d["others"] = augmented_data["others"] + return d + + +class AddSignald(RandomizableTransform): + def __init__(self): + self.perturb = "distance" + self.drop_rate = 0.5 + self.jitter_range = 3 + + def __call__(self, data): + d = dict(data) + + img = d["image"] + mask = d["label"] + others = d["others"] + + # Transform-x: create the guiding signals + signal_gen = PointGuidingSignal(mask, others, perturb=self.perturb) + inc_sig = signal_gen.inclusion_map() + exc_sig = signal_gen.exclusion_map(random_drop=self.drop_rate, random_jitter=self.jitter_range) + + img = np.float32(img) / 255.0 + img = np.moveaxis(img, -1, 0) -def adaptive_distance_thresholding(mask): - """Refining the input mask using adaptive distance thresholding. 
- - Distance map of the input mask is generated and the an adaptive - (random) threshold based on the distance map is calculated to - generate a new mask from distance map based on it. - - Inputs: - mask (::np.ndarray::): Should be a 2D binary numpy array (uint8) - Outputs: - new_mask (::np.ndarray::): the refined mask - dist (::np.ndarray::): the distance map - """ - dist = cv2.distanceTransform(mask, cv2.DIST_L2, 0) - tempMean = np.mean(dist[dist > 0]) - tempStd = np.std(dist[dist > 0]) - tempTol = tempStd / 2 - low_thresh = np.max([tempMean - tempTol, 0]) - high_thresh = np.min([tempMean + tempTol, np.max(dist) - tempTol]) - if low_thresh >= high_thresh: - thresh = tempMean - else: - thresh = np.random.uniform(low_thresh, high_thresh) - new_mask = dist > thresh - if np.all(new_mask == np.zeros_like(new_mask)): - new_mask = dist > tempMean - return new_mask, dist - - -class GuidingSignal: - """A generic class for defining guiding signal generators. - - This class include some special methods that inclusion and exclusion guiding signals - for different application can be created based on. - """ - - def __init__(self, mask: np.ndarray, others: np.ndarray, kernel_size: int = 0) -> None: + img = np.concatenate((img, inc_sig[np.newaxis, ...], exc_sig[np.newaxis, ...]), axis=0) + d["image"] = torch.as_tensor(img.copy()).float().contiguous() + d["label"] = torch.as_tensor(mask[np.newaxis, ...].copy()).long().contiguous() + return d + + +class PointGuidingSignal: + def __init__(self, mask: np.ndarray, others: np.ndarray, kernel_size=0, perturb: str = "None") -> None: self.mask = self.mask_validator(mask > 0.5) - self.kernel_size = kernel_size - if kernel_size: - self.current_mask = self.mask_preprocess(self.mask, kernel_size=self.kernel_size) - else: - self.current_mask = self.mask_validator(mask > 0.5) self.others = others + self.kernel_size = kernel_size + self.current_mask = ( + self.mask_preprocess(self.mask, kernel_size=self.kernel_size) + if kernel_size + else self.mask_validator(mask > 0.5) + ) + + if perturb.lower() not in {"none", "distance", "inside"}: + raise ValueError( + f'Invalid running perturb type of: {perturb}. Perturn type should be `"None"`, `"inside"`, or `"distance"`.' + ) + self.perturb = perturb.lower() @staticmethod def mask_preprocess(mask, kernel_size=3): @@ -286,24 +301,6 @@ def mask_validator(mask): mask = np.uint8(mask) return mask - def inclusion_map(self): - """A function to generate inclusion gioding signal""" - raise NotImplementedError - - def exclusion_map(self): - """A function to generate exclusion gioding signal""" - raise NotImplementedError - - -class PointGuidingSignal(GuidingSignal): - def __init__(self, mask: np.ndarray, others: np.ndarray, perturb: str = "None", **kwargs) -> None: - super().__init__(mask, others, **kwargs) - if perturb.lower() not in {"none", "distance", "inside"}: - raise ValueError( - f'Invalid running perturb type of: {perturb}. Perturn type should be `"None"`, `"inside"`, or `"distance"`.' 
- ) - self.perturb = perturb.lower() - def inclusion_map(self): if self.perturb is None: # if there is no purturbation indices = np.argwhere(self.current_mask == 1) # @@ -312,7 +309,7 @@ def inclusion_map(self): pointMask[int(centroid[0]), int(centroid[1]), 0] = 1 return pointMask, self.current_mask elif self.perturb == "distance" and np.any(self.current_mask > 0): - new_mask, _ = adaptive_distance_thresholding(self.current_mask) + new_mask, _ = self.adaptive_distance_thresholding(self.current_mask) else: # if self.perturb=='inside': new_mask = self.current_mask.copy() @@ -355,3 +352,32 @@ def jitterClicks(shape, centroids, jitter_range=3): centroids[:, 0] = np.clip(centroids[:, 0], 0, shape[1] - 1) centroids[:, 1] = np.clip(centroids[:, 1], 0, shape[0] - 1) return centroids + + @staticmethod + def adaptive_distance_thresholding(mask): + """Refining the input mask using adaptive distance thresholding. + + Distance map of the input mask is generated and the an adaptive + (random) threshold based on the distance map is calculated to + generate a new mask from distance map based on it. + + Inputs: + mask (::np.ndarray::): Should be a 2D binary numpy array (uint8) + Outputs: + new_mask (::np.ndarray::): the refined mask + dist (::np.ndarray::): the distance map + """ + dist = cv2.distanceTransform(mask, cv2.DIST_L2, 0) + tempMean = np.mean(dist[dist > 0]) + tempStd = np.std(dist[dist > 0]) + tempTol = tempStd / 2 + low_thresh = np.max([tempMean - tempTol, 0]) + high_thresh = np.min([tempMean + tempTol, np.max(dist) - tempTol]) + if low_thresh >= high_thresh: + thresh = tempMean + else: + thresh = np.random.uniform(low_thresh, high_thresh) + new_mask = dist > thresh + if np.all(new_mask == np.zeros_like(new_mask)): + new_mask = dist > tempMean + return new_mask, dist diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py index 7716557b6..a414f4a45 100644 --- a/sample-apps/pathology/main.py +++ b/sample-apps/pathology/main.py @@ -192,14 +192,14 @@ def train_nuclick(app): request={ "name": "train_01", "model": model, - "max_epochs": 10, + "max_epochs": 100, "dataset": "CacheDataset", # PersistentDataset, CacheDataset "train_batch_size": 64, "val_batch_size": 32, "multi_gpu": False, "val_split": 0.2, "dataset_source": "nuclick", - "dataset_limit": 10, + "dataset_limit": 0, }, ) From 1373b782c85f9cc753d5ee13fc67975f82ae87a5 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 7 May 2022 19:05:24 -0700 Subject: [PATCH 3/7] Fix nuclick training Signed-off-by: Sachidanand Alle --- sample-apps/pathology/lib/trainers/nuclick.py | 413 +++++++----------- sample-apps/pathology/lib/utils.py | 49 ++- sample-apps/pathology/main.py | 20 +- 3 files changed, 208 insertions(+), 274 deletions(-) diff --git a/sample-apps/pathology/lib/trainers/nuclick.py b/sample-apps/pathology/lib/trainers/nuclick.py index 030ef8e3d..919e467a6 100644 --- a/sample-apps/pathology/lib/trainers/nuclick.py +++ b/sample-apps/pathology/lib/trainers/nuclick.py @@ -8,21 +8,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import logging +import math import os +import random -import albumentations as alb import cv2 import numpy as np import skimage import torch from ignite.metrics import Accuracy from lib.handlers import TensorBoardImageHandler -from lib.utils import split_dataset +from lib.transforms import FilterImaged +from lib.utils import split_dataset, split_nuclei_dataset +from monai.config import KeysCollection from monai.handlers import from_engine from monai.inferers import SimpleInferer from monai.losses import DiceLoss -from monai.transforms import Activationsd, AsDiscreted, LoadImaged, RandomizableTransform, Transform +from monai.transforms import ( + Activationsd, + AddChanneld, + AsChannelFirstd, + AsDiscreted, + EnsureTyped, + LoadImaged, + MapTransform, + RandomizableTransform, + RandRotate90d, + ScaleIntensityRangeD, + ToNumpyd, + TorchVisiond, + ToTensord, + Transform, +) +from skimage.measure import regionprops +from tqdm import tqdm from monailabel.interfaces.datastore import Datastore from monailabel.tasks.train.basic_train import BasicTrainTask, Context @@ -37,14 +58,16 @@ def __init__( network, labels, tile_size=(256, 256), - roi_size=(128, 128), + patch_size=128, + min_area=5, description="Pathology NuClick Segmentation", **kwargs, ): self._network = network self.labels = labels self.tile_size = tile_size - self.roi_size = roi_size + self.patch_size = patch_size + self.min_area = min_area super().__init__(model_dir, description, **kwargs) def network(self, context: Context): @@ -64,7 +87,7 @@ def pre_process(self, request, datastore: Datastore): max_region = request.get("dataset_max_region", (10240, 10240)) max_region = (max_region, max_region) if isinstance(max_region, int) else max_region[:2] - return split_dataset( + ds = split_dataset( datastore=datastore, cache_dir=cache_dir, source=source, @@ -75,12 +98,34 @@ def pre_process(self, request, datastore: Datastore): randomize=request.get("dataset_randomize", True), ) + logger.info(f"Split data (len: {len(ds)}) based on each nuclei") + ds_new = [] + limit = request.get("dataset_limit", 0) + for d in tqdm(ds): + ds_new.extend(split_nuclei_dataset(d, min_area=self.min_area)) + if 0 < limit < len(ds_new): + ds_new = ds_new[:limit] + break + return ds_new + def train_pre_transforms(self, context: Context): return [ LoadImaged(keys=("image", "label"), dtype=np.uint8), - ExtractPatchd(), - Augmentd(), - AddSignald(), + FilterImaged(keys="image", min_size=5), + FlattenLabel(keys="label"), + AsChannelFirstd(keys="image"), + AddChanneld(keys="label"), + ExtractPatchd(keys=("image", "label"), patch_size=self.patch_size), + SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=self.min_area), + ToTensord(keys="image"), + TorchVisiond( + keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04 + ), + ToNumpyd(keys="image"), + RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)), + ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0), + AddPointGuidanceSignald(image="image", label="label", others="others"), + EnsureTyped(keys=("image", "label")), ] def train_post_transforms(self, context: Context): @@ -105,73 +150,56 @@ def train_handlers(self, context: Context): return handlers -class ExtractPatchd(Transform): - def __init__(self): - self.patch_size = 128 +class FlattenLabel(MapTransform): + def __call__(self, data): + d = dict(data) + for key in self.keys: + _, labels, _, _ = cv2.connectedComponentsWithStats(d[key], 4, 
cv2.CV_32S) + d[key] = labels.astype(np.uint8) + return d + + +class ExtractPatchd(MapTransform): + def __init__(self, keys: KeysCollection, centroid_key="centroid", patch_size=128): + super().__init__(keys) + self.centroid_key = centroid_key + self.patch_size = patch_size def __call__(self, data): d = dict(data) - img = d["image"] - mask = d["label"] - y, x = d["centroid"] - m, n = mask.shape[:2] - - x_start = int(max(x - self.patch_size / 2, 0)) - y_start = int(max(y - self.patch_size / 2, 0)) - x_end = x_start + self.patch_size - y_end = y_start + self.patch_size - if x_end > n: - x_end = n - x_start = n - self.patch_size - if y_end > m: - y_end = m - y_start = m - self.patch_size - - mask_val = mask[y, x] - - img_patch = img[y_start:y_end, x_start:x_end, :] - mask_p = mask[y_start:y_end, x_start:x_end] - - mask_patch = (mask_p == mask_val).astype(np.uint8) - others_patch = (1 - mask_patch) * mask_p - others_patch = self._mask_relabeling(others_patch, size_limit=5).astype(np.uint8) - - pad_size = (self.patch_size, self.patch_size) - img_patch = self.pad_to_shape(img_patch, pad_size, False) - mask_patch = self.pad_to_shape(mask_patch, pad_size, True) - others_patch = self.pad_to_shape(others_patch, pad_size, True) - - d["image"] = img_patch - d["label"] = mask_patch - d["others"] = others_patch + centroid = d[self.centroid_key] # create mask based on centroid (select nuclei based on centroid) + roi_size = (self.patch_size, self.patch_size) + + for key in self.keys: + img = d[key] + x_start, x_end, y_start, y_end = self.bbox(self.patch_size, centroid, img.shape[-2:]) + cropped = img[:, x_start:x_end, y_start:y_end] + d[key] = self.pad_to_shape(cropped, roi_size) return d @staticmethod - def _mask_relabeling(mask, size_limit=5): - out_mask = np.zeros_like(mask, dtype=np.uint16) - unique_labels = np.unique(mask) - if unique_labels[0] == 0: - unique_labels = np.delete(unique_labels, 0) - - i = 1 - for l in unique_labels: - m = skimage.measure.label(mask == l, connectivity=1) - stats = skimage.measure.regionprops(m) - for stat in stats: - if stat.area > size_limit: - out_mask[stat.coords[:, 0], stat.coords[:, 1]] = i - i += 1 - return out_mask + def bbox(patch_size, centroid, size): + x, y = centroid + m, n = size + + x_start = int(max(x - patch_size / 2, 0)) + y_start = int(max(y - patch_size / 2, 0)) + x_end = x_start + patch_size + y_end = y_start + patch_size + if x_end > m: + x_end = m + x_start = m - patch_size + if y_end > n: + y_end = n + y_start = n - patch_size + return x_start, x_end, y_start, y_end @staticmethod - def pad_to_shape(img, shape, is_mask): - img_shape = img.shape[:2] + def pad_to_shape(img, shape): + img_shape = img.shape[-2:] s_diff = np.array(shape) - np.array(img_shape) - diff = [(0, s_diff[0]), (0, s_diff[1])] - if not is_mask: - diff.append((0, 0)) - + diff = [(0, 0), (0, s_diff[0]), (0, s_diff[1])] return np.pad( img, diff, @@ -180,204 +208,95 @@ def pad_to_shape(img, shape, is_mask): ) -class Augmentd(RandomizableTransform): - def __init__(self): - self.train_augs = alb.Compose( - [ - alb.OneOf( - [ - alb.HueSaturationValue( - hue_shift_limit=10, - sat_shift_limit=(-30, 20), - val_shift_limit=0, - always_apply=False, - p=0.75, - ), # .8 - alb.RGBShift( - r_shift_limit=20, - g_shift_limit=20, - b_shift_limit=20, - p=0.75, - ), # .7 - ], - p=1.0, - ), - alb.OneOf( - [ - alb.GaussianBlur(blur_limit=(3, 5), p=0.5), - alb.Sharpen(alpha=(0.1, 0.3), lightness=(1.0, 1.0), p=0.5), - alb.ImageCompression(quality_lower=30, quality_upper=80, p=0.5), - ], - p=1.0, - ), - 
alb.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), - alb.ShiftScaleRotate( - shift_limit=0.1, - scale_limit=0.2, - rotate_limit=180, - border_mode=cv2.BORDER_CONSTANT, - value=0, - p=0.5, - ), - alb.Flip(p=0.5), - ], - additional_targets={"others": "mask"}, - p=0.5, - ) +class SplitLabeld(Transform): + def __init__(self, label="label", others="others", mask_value="mask_value", min_area=5): + self.label = label + self.others = others + self.mask_value = mask_value + self.min_area = min_area def __call__(self, data): d = dict(data) - img = d["image"] - mask = d["label"] - others = d["others"] - - augmented_data = self.train_augs(image=img, mask=mask, others=others) - d["image"] = augmented_data["image"] - d["label"] = augmented_data["mask"] - d["others"] = augmented_data["others"] - return d + label = d[self.label] + mask_value = d[self.mask_value] + mask = np.uint8(label == mask_value) + others = (1 - mask) * label + others = self._mask_relabeling(others[0], min_area=self.min_area)[np.newaxis] -class AddSignald(RandomizableTransform): - def __init__(self): - self.perturb = "distance" - self.drop_rate = 0.5 - self.jitter_range = 3 + d[self.label] = mask + d[self.others] = others + return d - def __call__(self, data): - d = dict(data) + @staticmethod + def _mask_relabeling(mask, min_area=5): + res = np.zeros_like(mask) + for l in np.unique(mask): + if l == 0: + continue - img = d["image"] - mask = d["label"] - others = d["others"] + m = skimage.measure.label(mask == l, connectivity=1) + for stat in skimage.measure.regionprops(m): + if stat.area > min_area: + res[stat.coords[:, 0], stat.coords[:, 1]] = l + return res - # Transform-x: create the guiding signals - signal_gen = PointGuidingSignal(mask, others, perturb=self.perturb) - inc_sig = signal_gen.inclusion_map() - exc_sig = signal_gen.exclusion_map(random_drop=self.drop_rate, random_jitter=self.jitter_range) - img = np.float32(img) / 255.0 - img = np.moveaxis(img, -1, 0) +class AddPointGuidanceSignald(RandomizableTransform): + def __init__( + self, image="image", label="label", others="others", perturb="distance", drop_rate=0.5, jitter_range=3 + ): + super().__init__() - img = np.concatenate((img, inc_sig[np.newaxis, ...], exc_sig[np.newaxis, ...]), axis=0) - d["image"] = torch.as_tensor(img.copy()).float().contiguous() - d["label"] = torch.as_tensor(mask[np.newaxis, ...].copy()).long().contiguous() - return d + self.image = image + self.label = label + self.others = others + self.perturb = perturb + self.drop_rate = drop_rate + self.jitter_range = jitter_range + def __call__(self, data): + d = dict(data) -class PointGuidingSignal: - def __init__(self, mask: np.ndarray, others: np.ndarray, kernel_size=0, perturb: str = "None") -> None: - self.mask = self.mask_validator(mask > 0.5) - self.others = others - self.kernel_size = kernel_size - self.current_mask = ( - self.mask_preprocess(self.mask, kernel_size=self.kernel_size) - if kernel_size - else self.mask_validator(mask > 0.5) - ) + image = d[self.image] + mask = d[self.label] + others = d[self.others] - if perturb.lower() not in {"none", "distance", "inside"}: - raise ValueError( - f'Invalid running perturb type of: {perturb}. Perturn type should be `"None"`, `"inside"`, or `"distance"`.' 
- ) - self.perturb = perturb.lower() + inc_sig = self.inclusion_map(mask[0]) + exc_sig = self.exclusion_map(others[0], drop_rate=self.drop_rate, jitter_range=self.jitter_range) - @staticmethod - def mask_preprocess(mask, kernel_size=3): - kernel = np.ones((kernel_size, kernel_size), np.uint8) - mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) - if np.all(mask == np.zeros_like(mask)): - logging.warning( - f"The kernel_size (radius) of {kernel_size} may be too high, consider checking " - "the intermediate output for the sanity of generated masks." - ) - return mask + image = np.concatenate((image, inc_sig[np.newaxis, ...], exc_sig[np.newaxis, ...]), axis=0) + d[self.image] = image + return d @staticmethod - def mask_validator(mask): - """Validate the input mask be np.uint8 and 2D""" - assert len(mask.shape) == 2, "Mask must be a 2D array (NxM)" - if not issubclass(type(mask[0, 0]), np.integer): - mask = np.uint8(mask) - return mask - - def inclusion_map(self): - if self.perturb is None: # if there is no purturbation - indices = np.argwhere(self.current_mask == 1) # - centroid = np.mean(indices, axis=0) - pointMask = np.zeros_like(self.current_mask) - pointMask[int(centroid[0]), int(centroid[1]), 0] = 1 - return pointMask, self.current_mask - elif self.perturb == "distance" and np.any(self.current_mask > 0): - new_mask, _ = self.adaptive_distance_thresholding(self.current_mask) - else: # if self.perturb=='inside': - new_mask = self.current_mask.copy() - - # Creating the point map - pointMask = np.zeros_like(self.current_mask) - indices = np.argwhere(new_mask == 1) + def inclusion_map(mask): + point_mask = np.zeros_like(mask) + indices = np.argwhere(mask > 0) if len(indices) > 0: - rndIdx = np.random.randint(0, len(indices)) - rndX = indices[rndIdx, 1] - rndY = indices[rndIdx, 0] - pointMask[rndY, rndX] = 1 - - return pointMask - - def exclusion_map(self, random_drop=0.0, random_jitter=0): - _, _, _, centroids = cv2.connectedComponentsWithStats(self.others, 4, cv2.CV_32S) + idx = np.random.randint(0, len(indices)) + point_mask[indices[idx, 0], indices[idx, 1]] = 1 - centroids = centroids[1:, :] # removing the first centroid, it's background - if random_jitter: - centroids = self.jitterClicks(self.current_mask.shape, centroids, jitter_range=random_jitter) - if random_drop: # randomly dropping some of the points - drop_prob = np.random.uniform(0, random_drop) - num_select = int((1 - drop_prob) * centroids.shape[0]) - select_indices = np.random.choice(centroids.shape[0], size=num_select, replace=False) - centroids = centroids[select_indices, :] - centroids = np.int64(np.floor(centroids)) - - # create the point map - pointMask = np.zeros_like(self.others) - pointMask[centroids[:, 1], centroids[:, 0]] = 1 - - return pointMask - - @staticmethod - def jitterClicks(shape, centroids, jitter_range=3): - """Randomly jitter the centroid points - Points should be an array in (x, y) format while shape is (H, W) of the point map - """ - centroids += np.random.uniform(low=-jitter_range, high=jitter_range, size=centroids.shape) - centroids[:, 0] = np.clip(centroids[:, 0], 0, shape[1] - 1) - centroids[:, 1] = np.clip(centroids[:, 1], 0, shape[0] - 1) - return centroids + return point_mask @staticmethod - def adaptive_distance_thresholding(mask): - """Refining the input mask using adaptive distance thresholding. - - Distance map of the input mask is generated and the an adaptive - (random) threshold based on the distance map is calculated to - generate a new mask from distance map based on it. 
- - Inputs: - mask (::np.ndarray::): Should be a 2D binary numpy array (uint8) - Outputs: - new_mask (::np.ndarray::): the refined mask - dist (::np.ndarray::): the distance map - """ - dist = cv2.distanceTransform(mask, cv2.DIST_L2, 0) - tempMean = np.mean(dist[dist > 0]) - tempStd = np.std(dist[dist > 0]) - tempTol = tempStd / 2 - low_thresh = np.max([tempMean - tempTol, 0]) - high_thresh = np.min([tempMean + tempTol, np.max(dist) - tempTol]) - if low_thresh >= high_thresh: - thresh = tempMean - else: - thresh = np.random.uniform(low_thresh, high_thresh) - new_mask = dist > thresh - if np.all(new_mask == np.zeros_like(new_mask)): - new_mask = dist > tempMean - return new_mask, dist + def exclusion_map(others, jitter_range=3, drop_rate=0.5): + point_mask = np.zeros_like(others) + max_x = point_mask.shape[0] - 1 + max_y = point_mask.shape[1] - 1 + + stats = regionprops(others) + for stat in stats: + x, y = stat.centroid + # random drop + if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]): + continue + + # random jitter + x = int(math.floor(x)) + random.randint(a=-jitter_range, b=jitter_range) + y = int(math.floor(y)) + random.randint(a=-jitter_range, b=jitter_range) + x = min(max(0, x), max_x) + y = min(max(0, y), max_y) + point_mask[x, y] = 1 + + return point_mask diff --git a/sample-apps/pathology/lib/utils.py b/sample-apps/pathology/lib/utils.py index b1b66d914..9dd35280b 100644 --- a/sample-apps/pathology/lib/utils.py +++ b/sample-apps/pathology/lib/utils.py @@ -1,5 +1,6 @@ import copy import logging +import math import os import random import shutil @@ -29,7 +30,9 @@ def split_dataset( ds = datastore.datalist() shutil.rmtree(cache_dir, ignore_errors=True) - if source == "pannuke": + if source == "none": + pass + elif source == "pannuke": image = np.load(ds[0]["image"]) if len(ds) == 1 else None if image is not None and len(image.shape) > 3: logger.info(f"PANNuke (For Developer Mode only):: Split data; groups: {groups}") @@ -198,19 +201,28 @@ def split_local_dataset(datastore, d, output_dir, groups, tile_size, max_region= return dataset_json -def split_nuclei_dataset(d, centroid_key="centroid"): +def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value", min_area=5): dataset_json = [] + mask = LoadImage(image_only=True, dtype=np.uint8)(d["label"]) - stats = regionprops(mask) + _, labels, _, _ = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S) + + stats = regionprops(labels) for stat in stats: - y, x = stat.centroid - y = int(np.floor(y)) - x = int(np.floor(x)) - - if mask[y, x]: - item = copy.deepcopy(d) - item[centroid_key] = (y, x) - dataset_json.append(item) + if stat.area < min_area: + logger.info(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})") + continue + + x, y = stat.centroid + x = int(math.floor(x)) + y = int(math.floor(y)) + + item = copy.deepcopy(d) + item[centroid_key] = (x, y) + item[mask_value_key] = stat.label + + # logger.info(f"{d['label']} => {len(stats)} => {mask.shape} => {stat.label}") + dataset_json.append(item) return dataset_json @@ -353,8 +365,15 @@ def main_nuke(): datefmt="%Y-%m-%d %H:%M:%S", ) - datastore = LocalDatastore("/localhome/sachi/Data/Pathology/PanNuke", extensions=("*.nii.gz", "*.nii", "*.npy")) - split_dataset(datastore, "/localhome/sachi/Data/Pathology/PanNukeF", "pannuke", "Nuclei", None) + datastore = LocalDatastore("/localhome/sachi/Datasets/pannuke", extensions=("*.npy")) + labels = { + "Neoplastic cells": 1, + "Inflammatory": 2, + "Connective/Soft tissue cells": 3, + "Dead 
Cells": 4, + "Epithelial": 5, + } + split_dataset(datastore, "/localhome/sachi/Datasets/pannukeF", "pannuke", labels, None) def main_local(): @@ -385,7 +404,9 @@ def main_nuclei(): datefmt="%Y-%m-%d %H:%M:%S", ) - datastore = LocalDatastore("/localhome/sachi/Datasets/NuClick", extensions=("*.png", "*.npy")) + # s = "/localhome/sachi/Datasets/NuClick" + s = "/localhome/sachi/Datasets/pannukeF" + datastore = LocalDatastore(s, extensions=("*.png", "*.npy")) split_dataset(datastore, None, "nuclick", None, None, limit=0) diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py index a414f4a45..8b73544b2 100644 --- a/sample-apps/pathology/main.py +++ b/sample-apps/pathology/main.py @@ -162,14 +162,7 @@ def main(): run_train = True home = str(Path.home()) - if run_train: - studies = f"{home}/Datasets/NuClick" - # studies = f"{home}/Data/Pathology/PanNuke" - # studies = "http://0.0.0.0:8080/api/v1" - else: - # studies = f"{home}/Datasets/Pathology/" - # studies = "C:\\Projects\\Pathology\\Test" - studies = f"{home}/Datasets/NuClick" + studies = f"{home}/Datasets/pannukeF" parser = argparse.ArgumentParser() parser.add_argument("-s", "--studies", default=studies) @@ -193,13 +186,14 @@ def train_nuclick(app): "name": "train_01", "model": model, "max_epochs": 100, - "dataset": "CacheDataset", # PersistentDataset, CacheDataset - "train_batch_size": 64, - "val_batch_size": 32, - "multi_gpu": False, + "dataset": "PersistentDataset", # PersistentDataset, CacheDataset + "train_batch_size": 128, + "val_batch_size": 64, + "multi_gpu": True, "val_split": 0.2, - "dataset_source": "nuclick", + "dataset_source": "none", "dataset_limit": 0, + "pretrained": False, }, ) From 7162cbef1c3b493d3ecd41e63e49aa090f8d573d Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 7 May 2022 19:09:15 -0700 Subject: [PATCH 4/7] rename transform Signed-off-by: Sachidanand Alle --- sample-apps/pathology/lib/trainers/nuclick.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sample-apps/pathology/lib/trainers/nuclick.py b/sample-apps/pathology/lib/trainers/nuclick.py index 919e467a6..5302d6fe5 100644 --- a/sample-apps/pathology/lib/trainers/nuclick.py +++ b/sample-apps/pathology/lib/trainers/nuclick.py @@ -42,7 +42,6 @@ ToTensord, Transform, ) -from skimage.measure import regionprops from tqdm import tqdm from monailabel.interfaces.datastore import Datastore @@ -112,7 +111,7 @@ def train_pre_transforms(self, context: Context): return [ LoadImaged(keys=("image", "label"), dtype=np.uint8), FilterImaged(keys="image", min_size=5), - FlattenLabel(keys="label"), + FlattenLabeld(keys="label"), AsChannelFirstd(keys="image"), AddChanneld(keys="label"), ExtractPatchd(keys=("image", "label"), patch_size=self.patch_size), @@ -150,7 +149,7 @@ def train_handlers(self, context: Context): return handlers -class FlattenLabel(MapTransform): +class FlattenLabeld(MapTransform): def __call__(self, data): d = dict(data) for key in self.keys: @@ -285,7 +284,7 @@ def exclusion_map(others, jitter_range=3, drop_rate=0.5): max_x = point_mask.shape[0] - 1 max_y = point_mask.shape[1] - 1 - stats = regionprops(others) + stats = skimage.measure.regionprops(others) for stat in stats: x, y = stat.centroid # random drop From 6b1bf8b8cfdfee6de0433ab0d241d36ba5e4737c Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 7 May 2022 19:25:29 -0700 Subject: [PATCH 5/7] Sync up changes for nuclick training Signed-off-by: Sachidanand Alle --- sample-apps/pathology/lib/trainers/nuclick.py | 5 +---- 1 file changed, 1 
insertion(+), 4 deletions(-) diff --git a/sample-apps/pathology/lib/trainers/nuclick.py b/sample-apps/pathology/lib/trainers/nuclick.py index 5302d6fe5..725d34788 100644 --- a/sample-apps/pathology/lib/trainers/nuclick.py +++ b/sample-apps/pathology/lib/trainers/nuclick.py @@ -242,15 +242,12 @@ def _mask_relabeling(mask, min_area=5): class AddPointGuidanceSignald(RandomizableTransform): - def __init__( - self, image="image", label="label", others="others", perturb="distance", drop_rate=0.5, jitter_range=3 - ): + def __init__(self, image="image", label="label", others="others", drop_rate=0.5, jitter_range=3): super().__init__() self.image = image self.label = label self.others = others - self.perturb = perturb self.drop_rate = drop_rate self.jitter_range = jitter_range From 72cfa15550e97cf7a2032ce36d714998842a608c Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sun, 8 May 2022 13:24:36 -0700 Subject: [PATCH 6/7] use monai bunet for nuclick Signed-off-by: Sachidanand Alle --- monailabel/tasks/train/basic_train.py | 7 +- .../monailabel/commands/RunInference.java | 13 +- sample-apps/pathology/lib/configs/nuclick.py | 12 +- sample-apps/pathology/lib/handlers.py | 14 +- sample-apps/pathology/lib/nets/__init__.py | 11 - sample-apps/pathology/lib/nets/nuclick.py | 486 ------------------ sample-apps/pathology/lib/nets/unet.py | 105 ---- sample-apps/pathology/lib/utils.py | 2 +- sample-apps/pathology/main.py | 3 +- 9 files changed, 32 insertions(+), 621 deletions(-) delete mode 100644 sample-apps/pathology/lib/nets/__init__.py delete mode 100644 sample-apps/pathology/lib/nets/nuclick.py delete mode 100644 sample-apps/pathology/lib/nets/unet.py diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py index 1ea36d4da..ccd9aac1f 100644 --- a/monailabel/tasks/train/basic_train.py +++ b/monailabel/tasks/train/basic_train.py @@ -104,6 +104,7 @@ def __init__( stats_path=None, train_save_interval=20, val_interval=1, + n_saved=5, final_filename="checkpoint_final.pt", key_metric_filename="model.pt", model_dict_key="model", @@ -123,6 +124,7 @@ def __init__( :param stats_path: Path to save the train stats :param train_save_interval: checkpoint save interval for training :param val_interval: validation interval (run every x epochs) + :param n_saved: max checkpoints to save :param final_filename: name of final checkpoint that will be saved :param key_metric_filename: best key metric model file name :param model_dict_key: key to save network weights into checkpoint @@ -157,6 +159,7 @@ def __init__( self._train_save_interval = train_save_interval self._val_interval = val_interval + self._n_saved = n_saved self._final_filename = final_filename self._key_metric_filename = key_metric_filename self._model_dict_key = model_dict_key @@ -528,7 +531,7 @@ def _create_evaluator(self, context: Context): save_dict={self._model_dict_key: context.network}, save_key_metric=True, key_metric_filename=self._key_metric_filename, - n_saved=5, + n_saved=self._n_saved, ) ) @@ -560,7 +563,7 @@ def _create_trainer(self, context: Context): key_metric_filename=f"train_{self._key_metric_filename}" if context.evaluator else self._key_metric_filename, - n_saved=5, + n_saved=self._n_saved, ) ) diff --git a/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java b/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java index 1b75a0e23..119f5f170 100644 --- 
a/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java +++ b/plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java @@ -94,17 +94,20 @@ public void run() { list.addIntParameter("Width", "Width", bbox[2]); list.addIntParameter("Height", "Height", bbox[3]); + boolean override = !info.models.get(selectedModel).nuclick; + list.addBooleanParameter("Override", "Override", override); + if (Dialogs.showParameterDialog("MONAILabel", list)) { String model = (String) list.getChoiceParameterValue("Model"); bbox[0] = list.getIntParameterValue("X").intValue(); bbox[1] = list.getIntParameterValue("Y").intValue(); bbox[2] = list.getIntParameterValue("Width").intValue(); bbox[3] = list.getIntParameterValue("Height").intValue(); + override = list.getBooleanParameterValue("Override").booleanValue(); selectedModel = model; selectedBBox = bbox; - boolean isNuClick = info.models.get(model).nuclick; - runInference(model, new HashSet(Arrays.asList(labels.get(model))), bbox, imageData, isNuClick); + runInference(model, new HashSet(Arrays.asList(labels.get(model))), bbox, imageData, override); } } catch (Exception ex) { ex.printStackTrace(); @@ -145,9 +148,9 @@ private int[] getBBOX(ROI roi) { } private void runInference(String model, Set labels, int[] bbox, ImageData imageData, - boolean isNuClick) throws SAXException, IOException, ParserConfigurationException, InterruptedException { + boolean override) throws SAXException, IOException, ParserConfigurationException, InterruptedException { logger.info("MONAILabel Annotation - Run Inference..."); - logger.info("Model: " + model + "; IsNuClick: " + isNuClick + "; Labels: " + labels); + logger.info("Model: " + model + "; override: " + override + "; Labels: " + labels); String image = Utils.getNameWithoutExtension(imageData.getServerPath()); @@ -163,7 +166,7 @@ private void runInference(String model, Set labels, int[] bbox, ImageDat Document dom = MonaiLabelClient.infer(model, image, req); NodeList annotation_list = dom.getElementsByTagName("Annotation"); - int count = updateAnnotations(labels, annotation_list, roi, imageData, !isNuClick); + int count = updateAnnotations(labels, annotation_list, roi, imageData, override); // Update hierarchy to see changes in QuPath's hierarchy QP.fireHierarchyUpdate(imageData.getHierarchy()); diff --git a/sample-apps/pathology/lib/configs/nuclick.py b/sample-apps/pathology/lib/configs/nuclick.py index fe61e3149..42533c2c3 100644 --- a/sample-apps/pathology/lib/configs/nuclick.py +++ b/sample-apps/pathology/lib/configs/nuclick.py @@ -16,7 +16,7 @@ import lib.infers import lib.trainers -from lib.nets import UNet +from monai.networks.nets import BasicUNet from monailabel.interfaces.config import TaskConfig from monailabel.interfaces.tasks.infer import InferTask @@ -42,11 +42,16 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, ** # Download PreTrained Model if strtobool(self.conf.get("use_pretrained_model", "true")): - url = f"{self.PRE_TRAINED_PATH}/NuClick_UNet_40xAll.pth" + url = f"{self.PRE_TRAINED_PATH}/pathology_nuclick_bunet.pt" download_file(url, self.path[0]) # Network - self.network = UNet(n_channels=5, n_classes=1) + self.network = BasicUNet( + spatial_dims=2, + in_channels=5, + out_channels=1, + features=(32, 64, 128, 256, 512, 32), + ) def infer(self) -> Union[InferTask, Dict[str, InferTask]]: task: InferTask = lib.infers.NuClick( @@ -67,6 +72,7 @@ def trainer(self) -> Optional[TrainTask]: publish_path=self.path[1], 
labels=self.labels, description="Train Nuclei DeepEdit Model", + train_save_interval=1, config={ "max_epochs": 10, "train_batch_size": 64, diff --git a/sample-apps/pathology/lib/handlers.py b/sample-apps/pathology/lib/handlers.py index ed32a7e2d..288b46d96 100644 --- a/sample-apps/pathology/lib/handlers.py +++ b/sample-apps/pathology/lib/handlers.py @@ -174,15 +174,15 @@ def write_images(self, batch_data, output_data, epoch): break def write_region_metrics(self, epoch): - metric_sum = 0 - for region in self.metric_data: - metric = self.metric_data[region].mean() - self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}") + if len(self.metric_data) > 1: + metric_sum = 0 + for region in self.metric_data: + metric = self.metric_data[region].mean() + self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}") - self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch) - metric_sum += metric + self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch) + metric_sum += metric - if len(self.metric_data) > 1: metric_avg = metric_sum / len(self.metric_data) self.writer.add_scalar("dice_regions_avg", metric_avg, epoch) diff --git a/sample-apps/pathology/lib/nets/__init__.py b/sample-apps/pathology/lib/nets/__init__.py deleted file mode 100644 index 2bca7a6c9..000000000 --- a/sample-apps/pathology/lib/nets/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from .unet import UNet diff --git a/sample-apps/pathology/lib/nets/nuclick.py b/sample-apps/pathology/lib/nets/nuclick.py deleted file mode 100644 index b8d6ebe03..000000000 --- a/sample-apps/pathology/lib/nets/nuclick.py +++ /dev/null @@ -1,486 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -bn_axis = 1 - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) - # if you have padding issues, see - # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) - - -# --------------------New layers development in progress----------------- -"""(convolution => [BN] => ReLU/sigmoid)""" -"""No regularizer""" - - -class Conv_Bn_Relu(nn.Module): - def __init__( - self, - in_channels, - out_channels=32, - kernelSize=(3, 3), - strds=(1, 1), - useBias=False, - dilatationRate=(1, 1), - actv="relu", - doBatchNorm=True, - ): - - super().__init__() - if isinstance(kernelSize, int): - kernelSize = (kernelSize, kernelSize) - if isinstance(strds, int): - strds = (strds, strds) - - self.conv_bn_relu = self.get_block( - in_channels, out_channels, kernelSize, strds, useBias, dilatationRate, actv, doBatchNorm - ) - - def forward(self, input): - return self.conv_bn_relu(input) - - def get_block(self, in_channels, out_channels, kernelSize, strds, useBias, dilatationRate, actv, doBatchNorm): - - layers = [] - - conv1 = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernelSize, - stride=strds, - dilation=dilatationRate, - bias=useBias, - padding="same", - padding_mode="zeros", - ) - - if actv == 
"selu": - # (Can't find 'lecun_normal' equivalent in PyTorch) - torch.nn.init.xavier_normal_(conv1.weight) - else: - torch.nn.init.xavier_uniform_(conv1.weight) - - layers.append(conv1) - - if actv != "selu" and doBatchNorm: - layers.append(nn.BatchNorm2d(num_features=out_channels, eps=1.001e-5)) - - if actv == "relu": - layers.append(nn.ReLU()) - elif actv == "sigmoid": - layers.append(nn.Sigmoid()) - elif actv == "selu": - layers.append(nn.SELU()) - - block = nn.Sequential(*layers) - return block - - -"""Multiscale Conv Block""" - - -class Multiscale_Conv_Block(nn.Module): - def __init__( - self, - in_channels, - kernelSizes, - dilatationRates, - out_channels=32, - strds=(1, 1), - actv="relu", - useBias=False, - isDense=True, - ): - - super().__init__() - - # Initialise conv blocks - if isDense: - self.conv_block_0 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=4 * out_channels, - kernelSize=1, - strds=strds, - actv=actv, - useBias=useBias, - ) - self.conv_block_5 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=out_channels, - kernelSize=3, - strds=strds, - actv=actv, - useBias=useBias, - ) - else: - self.conv_block_0 = None - self.conv_block_5 = None - - self.conv_block_1 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=out_channels, - kernelSize=kernelSizes[0], - strds=strds, - actv=actv, - useBias=useBias, - dilatationRate=(dilatationRates[0], dilatationRates[0]), - ) - - self.conv_block_2 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=out_channels, - kernelSize=kernelSizes[1], - strds=strds, - actv=actv, - useBias=useBias, - dilatationRate=(dilatationRates[1], dilatationRates[1]), - ) - - self.conv_block_3 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=out_channels, - kernelSize=kernelSizes[2], - strds=strds, - actv=actv, - useBias=useBias, - dilatationRate=(dilatationRates[2], dilatationRates[2]), - ) - - self.conv_block_4 = Conv_Bn_Relu( - in_channels=in_channels, - out_channels=out_channels, - kernelSize=kernelSizes[3], - strds=strds, - actv=actv, - useBias=useBias, - dilatationRate=(dilatationRates[3], dilatationRates[3]), - ) - - def forward(self, input_map): - # If isDense == True - if self.conv_block_0 is not None: - conv0 = self.conv_block_0(input_map) - else: - conv0 = input_map - - conv1 = self.conv_block_1(conv0) - conv2 = self.conv_block_2(conv0) - conv3 = self.conv_block_3(conv0) - conv4 = self.conv_block_4(conv0) - - # (Not sure about bn_axis) - output_map = torch.cat([conv1, conv2, conv3, conv4], dim=bn_axis) - - # If isDense == True - if self.conv_block_5 is not None: - output_map = self.conv_block_5(output_map) - # (Not sure about bn_axis) - output_map = torch.cat([input_map, output_map], dim=bn_axis) - - return output_map - - -"""Residual_Conv""" - - -class Residual_Conv(nn.Module): - def __init__( - self, - in_channels, - out_channels=32, - kernelSize=(3, 3), - strds=(1, 1), - actv="relu", - useBias=False, - dilatationRate=(1, 1), - ): - super().__init__() - - if actv == "selu": - self.conv_block_1 = Conv_Bn_Relu( - in_channels, - out_channels, - kernelSize=kernelSize, - strds=strds, - actv="None", - useBias=useBias, - dilatationRate=dilatationRate, - doBatchNorm=False, - ) - self.conv_block_2 = Conv_Bn_Relu( - in_channels, - out_channels, - kernelSize=kernelSize, - strds=strds, - actv="None", - useBias=useBias, - dilatationRate=dilatationRate, - doBatchNorm=False, - ) - self.activation = nn.SELU() - else: - self.conv_block_1 = Conv_Bn_Relu( - in_channels, - out_channels, - kernelSize=kernelSize, - strds=strds, - 
actv="None", - useBias=useBias, - dilatationRate=dilatationRate, - doBatchNorm=True, - ) - self.conv_block_2 = Conv_Bn_Relu( - out_channels, - out_channels, - kernelSize=kernelSize, - strds=strds, - actv="None", - useBias=useBias, - dilatationRate=dilatationRate, - doBatchNorm=True, - ) - - if actv == "relu": - self.activation = nn.ReLU() - - if actv == "sigmoid": - self.activation = nn.Sigmoid() - - def forward(self, input): - conv1 = self.conv_block_1(input) - conv2 = self.conv_block_2(conv1) - - out = torch.add(conv1, conv2) - out = self.activation(out) - return out - - -class NuClick_NN(nn.Module): - def __init__(self, n_channels, n_classes): - super().__init__() - self.net_name = "NuClick" - - self.n_channels = n_channels - self.n_classes = n_classes - - # -------------Conv_Bn_Relu blocks------------ - self.conv_block_1 = nn.Sequential( - Conv_Bn_Relu(in_channels=self.n_channels, out_channels=64, kernelSize=7), - Conv_Bn_Relu(in_channels=64, out_channels=32, kernelSize=5), - Conv_Bn_Relu(in_channels=32, out_channels=32, kernelSize=3), - ) - - self.conv_block_2 = nn.Sequential( - Conv_Bn_Relu(in_channels=64, out_channels=64), - Conv_Bn_Relu(in_channels=64, out_channels=32), - Conv_Bn_Relu(in_channels=32, out_channels=32), - ) - - self.conv_block_3 = Conv_Bn_Relu( - in_channels=32, out_channels=self.n_classes, kernelSize=(1, 1), actv=None, useBias=True, doBatchNorm=False - ) - - # -------------Residual_Conv blocks------------ - self.residual_block_1 = nn.Sequential( - Residual_Conv(in_channels=32, out_channels=64), Residual_Conv(in_channels=64, out_channels=64) - ) - - self.residual_block_2 = Residual_Conv(in_channels=64, out_channels=128) - - self.residual_block_3 = Residual_Conv(in_channels=128, out_channels=128) - - self.residual_block_4 = nn.Sequential( - Residual_Conv(in_channels=128, out_channels=256), - Residual_Conv(in_channels=256, out_channels=256), - Residual_Conv(in_channels=256, out_channels=256), - ) - - self.residual_block_5 = nn.Sequential( - Residual_Conv(in_channels=256, out_channels=512), - Residual_Conv(in_channels=512, out_channels=512), - Residual_Conv(in_channels=512, out_channels=512), - ) - - self.residual_block_6 = nn.Sequential( - Residual_Conv(in_channels=512, out_channels=1024), Residual_Conv(in_channels=1024, out_channels=1024) - ) - - self.residual_block_7 = nn.Sequential( - Residual_Conv(in_channels=1024, out_channels=512), Residual_Conv(in_channels=512, out_channels=256) - ) - - self.residual_block_8 = Residual_Conv(in_channels=512, out_channels=256) - - self.residual_block_9 = Residual_Conv(in_channels=256, out_channels=256) - - self.residual_block_10 = nn.Sequential( - Residual_Conv(in_channels=256, out_channels=128), Residual_Conv(in_channels=128, out_channels=128) - ) - - self.residual_block_11 = Residual_Conv(in_channels=128, out_channels=64) - - self.residual_block_12 = Residual_Conv(in_channels=64, out_channels=64) - - # -------------Multiscale_Conv_Block blocks------------ - self.multiscale_block_1 = Multiscale_Conv_Block( - in_channels=128, out_channels=32, kernelSizes=[3, 3, 5, 5], dilatationRates=[1, 3, 3, 6], isDense=False - ) - - self.multiscale_block_2 = Multiscale_Conv_Block( - in_channels=256, out_channels=64, kernelSizes=[3, 3, 5, 5], dilatationRates=[1, 3, 2, 3], isDense=False - ) - - self.multiscale_block_3 = Multiscale_Conv_Block( - in_channels=64, out_channels=16, kernelSizes=[3, 3, 5, 7], dilatationRates=[1, 3, 2, 6], isDense=False - ) - - # -------------MaxPool2d blocks------------ - self.pool_block_1 = nn.MaxPool2d(kernel_size=(2, 
2)) - self.pool_block_2 = nn.MaxPool2d(kernel_size=(2, 2)) - self.pool_block_3 = nn.MaxPool2d(kernel_size=(2, 2)) - self.pool_block_4 = nn.MaxPool2d(kernel_size=(2, 2)) - self.pool_block_5 = nn.MaxPool2d(kernel_size=(2, 2)) - - # -------------ConvTranspose2d blocks------------ - self.conv_transpose_1 = nn.ConvTranspose2d( - in_channels=1024, - out_channels=512, - kernel_size=2, - stride=(2, 2), - ) - - self.conv_transpose_2 = nn.ConvTranspose2d( - in_channels=256, - out_channels=256, - kernel_size=2, - stride=(2, 2), - ) - - self.conv_transpose_3 = nn.ConvTranspose2d( - in_channels=256, - out_channels=128, - kernel_size=2, - stride=(2, 2), - ) - - self.conv_transpose_4 = nn.ConvTranspose2d( - in_channels=128, - out_channels=64, - kernel_size=2, - stride=(2, 2), - ) - - self.conv_transpose_5 = nn.ConvTranspose2d( - in_channels=64, - out_channels=32, - kernel_size=2, - stride=(2, 2), - ) - - def forward(self, input): - conv1 = self.conv_block_1(input) # conv1: 32 channels - pool1 = self.pool_block_1(conv1) # poo1: 32 channels - - conv2 = self.residual_block_1(pool1) # conv2: 64 channels - pool2 = self.pool_block_2(conv2) # pool2: 64 channels - - conv3 = self.residual_block_2(pool2) # conv3: 128 channels - conv3 = self.multiscale_block_1(conv3) # conv3: 128 channels - conv3 = self.residual_block_3(conv3) # conv3: 128 channels - pool3 = self.pool_block_3(conv3) # pool3: 128 channels - - conv4 = self.residual_block_4(pool3) # conv4: 256 channels - pool4 = self.pool_block_4(conv4) # pool4: 512 channels - - conv5 = self.residual_block_5(pool4) # conv5: 512 channels - pool5 = self.pool_block_5(conv5) # pool5: 512 channels - - conv51 = self.residual_block_6(pool5) # conv51: 1024 channels - - up61 = torch.cat([self.conv_transpose_1(conv51), conv5], dim=1) # up61: 1024 channels - conv61 = self.residual_block_7(up61) # conv61: 256 channels - - up6 = torch.cat([self.conv_transpose_2(conv61), conv4], dim=1) # up6: 512 channels - conv6 = self.residual_block_8(up6) # conv6: 256 channels - conv6 = self.multiscale_block_2(conv6) # conv6: 256 channels - conv6 = self.residual_block_9(conv6) # conv6: 256 channels - - up7 = torch.cat([self.conv_transpose_3(conv6), conv3], dim=1) # up7: 256 channels - conv7 = self.residual_block_10(up7) # conv7: 128 channels - - up8 = torch.cat([self.conv_transpose_4(conv7), conv2], dim=1) # up8: 128 channels - conv8 = self.residual_block_11(up8) # conv8: 64 channels - conv8 = self.multiscale_block_3(conv8) # conv8: 64 channels - conv8 = self.residual_block_12(conv8) # conv8: 64 channels - - up9 = torch.cat([self.conv_transpose_5(conv8), conv1], dim=1) # up9: 64 channels - conv9 = self.conv_block_2(up9) # conv9: 32 channels - - conv10 = self.conv_block_3(conv9) # conv10: out_channels - - return conv10 diff --git a/sample-apps/pathology/lib/nets/unet.py b/sample-apps/pathology/lib/nets/unet.py deleted file mode 100644 index 60366523b..000000000 --- a/sample-apps/pathology/lib/nets/unet.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - 
nn.ReLU(inplace=True), - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) - # if you have padding issues, see - # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) - - -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): - super().__init__() - self.net_name = "UNet" - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - - self.inc = DoubleConv(n_channels, 64) - self.down1 = Down(64, 128) - self.down2 = Down(128, 256) - self.down3 = Down(256, 512) - factor = 2 if bilinear else 1 - self.down4 = Down(512, 1024 // factor) - self.up1 = Up(1024, 512 // factor, bilinear) - self.up2 = Up(512, 256 // factor, bilinear) - self.up3 = Up(256, 128 // factor, bilinear) - self.up4 = Up(128, 64, bilinear) - self.outc = OutConv(64, n_classes) - - def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - x5 = self.down4(x4) - x = self.up1(x5, x4) - x = self.up2(x, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits diff --git a/sample-apps/pathology/lib/utils.py b/sample-apps/pathology/lib/utils.py index 9dd35280b..e4eb7aca6 100644 --- a/sample-apps/pathology/lib/utils.py +++ b/sample-apps/pathology/lib/utils.py @@ -210,7 +210,7 @@ def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value" stats = regionprops(labels) for stat in stats: if stat.area < min_area: - logger.info(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})") + logger.debug(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})") continue x, y = stat.centroid diff --git a/sample-apps/pathology/main.py b/sample-apps/pathology/main.py index 8b73544b2..fd959dfa8 100644 --- a/sample-apps/pathology/main.py +++ b/sample-apps/pathology/main.py @@ -185,7 +185,7 @@ def train_nuclick(app): request={ "name": "train_01", "model": model, - "max_epochs": 100, + "max_epochs": 10, "dataset": "PersistentDataset", # PersistentDataset, CacheDataset "train_batch_size": 128, 
"val_batch_size": 64, @@ -194,6 +194,7 @@ def train_nuclick(app): "dataset_source": "none", "dataset_limit": 0, "pretrained": False, + "n_saved": 10, }, ) From de8b20cfd8b496becac95cff6931a53ac7ab11fd Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sun, 8 May 2022 13:32:24 -0700 Subject: [PATCH 7/7] fix log Signed-off-by: Sachidanand Alle --- sample-apps/pathology/lib/handlers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sample-apps/pathology/lib/handlers.py b/sample-apps/pathology/lib/handlers.py index 288b46d96..f19484ca0 100644 --- a/sample-apps/pathology/lib/handlers.py +++ b/sample-apps/pathology/lib/handlers.py @@ -141,7 +141,7 @@ def write_images(self, batch_data, output_data, epoch): label[y == region] = region self.logger.info( - "{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {}); (pos-nz: {}); (neg-nz: {})".format( + "{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {}); Sig: (pos-nz: {}, neg-nz: {})".format( bidx, region, image.shape, @@ -149,8 +149,8 @@ def write_images(self, batch_data, output_data, epoch): np.count_nonzero(label), y_pred.shape, np.count_nonzero(y_pred[region]), - np.count_nonzero(image[-2]), - np.count_nonzero(image[-1]), + np.count_nonzero(image[3]) if image.shape == 5 else 0, + np.count_nonzero(image[4]) if image.shape == 5 else 0, ) )