diff --git a/pathology/hovernet/README.MD b/pathology/hovernet/README.MD
new file mode 100644
index 0000000000..fef6268eb2
--- /dev/null
+++ b/pathology/hovernet/README.MD
@@ -0,0 +1,93 @@
+# HoVerNet Examples
+
+This folder contains Ignite-based examples that train and validate a HoVerNet model.
+It also has plain PyTorch notebooks that run training and evaluation.
+
+The implementation is based on:
+
+Simon Graham et al., "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images," Medical Image Analysis, 2019. https://arxiv.org/abs/1812.06499
+
+### 1. Data
+
+The CoNSeP dataset used in these examples can be downloaded from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/.
+- First, download the CoNSeP dataset to `data_root`.
+- Then run `prepare_patches.py` to extract training and test patches from the images; the folder layout the script expects is sketched below.
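+
+A sketch of the layout `prepare_patches.py` expects under `data_root`, inferred from the paths used in the script (adjust if your download unpacks differently):
+
+```
+data_root/
+├── Train/
+│   ├── Images/   # *.png image tiles
+│   └── Labels/   # *.mat files containing "inst_map" and "type_map"
+└── Test/
+    ├── Images/
+    └── Labels/
+```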
+
+### 2. Questions and bugs
+
+- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
+- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
+- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
+
+
+### 3. List of notebooks and examples
+#### [Prepare Your Data](./prepare_patches.py)
+This script prepares patches from the image tiles, following the implementation in https://github.com/vqdang/hover_net/blob/master/extract_patches.py. The prepared patches are saved in `data_root`/Prepared.
+
+```bash
+# Run to see all available options
+python ./prepare_patches.py -h
+
+# Prepare patches from images
+python ./prepare_patches.py \
+ --root `data_root`
+```
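+
+Internally, the script uses a `PatchExtractor` that mirror-pads each tile so that the central region of every patch stays inside the original image. A minimal usage sketch of that class, assuming the script's default 540x540 patch size and 164x164 stride:
+
+```python
+import numpy as np
+
+# PatchExtractor is defined in prepare_patches.py
+from prepare_patches import PatchExtractor
+
+extractor = PatchExtractor(patch_size=(540, 540), step_size=(164, 164))
+tile = np.full([1000, 1000, 3], 255, np.uint8)  # dummy RGB tile
+patches = extractor.extract(tile, "mirror")     # mirror-pad, then slide a window
+print(len(patches), patches[0].shape)           # each patch is (540, 540, 3)
+```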
+
+#### [HoVerNet Training](./training.py)
+This example uses a MONAI workflow to train a HoVerNet model on the prepared CoNSeP dataset.
+HoVerNet is trained via a two-stage approach: the model is first initialised with weights pre-trained on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848) and only the decoders are trained for the first 50 epochs; all layers are then fine-tuned for another 50 epochs. The `--stage` flag selects which stage to run.
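+
+The two stages correspond to the model construction in `create_model` of [training.py](./training.py); a condensed sketch (the pre-trained weights URL and the stage-0 checkpoint path are placeholders):
+
+```python
+import torch
+from monai.networks.nets import HoVerNet
+
+pretrained_model_url = "<URL of the pre-trained encoder weights>"  # placeholder
+stage0_ckpt_path = "<path/to/stage0/checkpoint.pt>"                # placeholder
+
+# stage 0: encoder initialised from pre-trained weights and frozen,
+# so only the decoders are trained
+model = HoVerNet(mode="original", in_channels=3, out_classes=5,
+                 pretrained_url=pretrained_model_url, freeze_encoder=True)
+
+# stage 1: no pre-trained weights and nothing frozen; resume from the
+# stage-0 checkpoint and fine-tune all layers
+model = HoVerNet(mode="original", in_channels=3, out_classes=5,
+                 pretrained_url=None, freeze_encoder=False)
+model.load_state_dict(torch.load(stage0_ckpt_path)["net"])
+```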
+
+Each user is responsible for checking the content of models/datasets and the applicable licenses and determining if suitable for the intended use.
+The license for the pre-trained model used in the examples differs from the MONAI license. Please check the source these weights are obtained from:
+https://github.com/vqdang/hover_net#data-format
+
+
+```bash
+# Run to see all available options
+python ./training.py -h
+
+# Train a HoVerNet model on a single GPU (replace the ckpt path with your own)
+export CUDA_VISIBLE_DEVICES=0; python training.py \
+ --ep 50 \
+ --stage 0 \
+ --bs 16 \
+ --root `save_root`
+export CUDA_VISIBLE_DEVICES=0; python training.py \
+ --ep 50 \
+ --stage 1 \
+ --bs 4 \
+ --root `save_root` \
+ --ckpt logs/stage0/checkpoint_epoch=50.pt
+
+# Train a HoVerNet model on multiple GPUs (NVIDIA) (replace the ckpt path with your own)
+torchrun --nnodes=1 --nproc_per_node=2 training.py \
+ --ep 50 \
+ --bs 8 \
+ --root `save_root` \
+ --stage 0
+torchrun --nnodes=1 --nproc_per_node=2 training.py \
+ --ep 50 \
+ --bs 2 \
+ --root `save_root` \
+ --stage 1 \
+ --ckpt logs/stage0/checkpoint_epoch=50.pt
+```
+
+#### [HoVerNet Validation](./evaluation.py)
+This example uses a MONAI workflow to evaluate a trained HoVerNet model on the prepared test data of the CoNSeP dataset.
+With the paper's metrics in `original` mode, we reproduce the results Dice: 0.82762, PQ: 0.48976 and F1d: 0.73592.
+```bash
+# Run to see all available options
+python ./evaluation.py -h
+
+# Evaluate a HoVerNet model
+python ./evaluation.py \
+ --root `save_root` \
+ --ckpt logs/stage1/checkpoint_epoch=50.pt
+```
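+
+The reported Dice score is computed on the nucleus-prediction (NP) branch, which the evaluator binarises first. A self-contained sketch of that post-processing step on a dummy prediction, mirroring `post_process_np` in [evaluation.py](./evaluation.py):
+
+```python
+import torch
+from monai.transforms import Activationsd, Compose, Lambdad
+from monai.utils.enums import HoVerNetBranch
+
+# softmax over the two NP channels, then threshold the foreground
+# probability (channel 1) at 0.5 to obtain a binary nucleus mask
+post_process_np = Compose([
+    Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+    Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1:2, ...] > 0.5),
+])
+
+pred = {HoVerNetBranch.NP.value: torch.randn(2, 80, 80)}  # dummy NP logits
+mask = post_process_np(pred)[HoVerNetBranch.NP.value]
+print(mask.shape)  # torch.Size([1, 80, 80])
+```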
+
+## Disclaimer
+
+This is an example, not to be used for diagnostic purposes.
diff --git a/pathology/hovernet/evaluation.py b/pathology/hovernet/evaluation.py
new file mode 100644
index 0000000000..476da843ca
--- /dev/null
+++ b/pathology/hovernet/evaluation.py
@@ -0,0 +1,150 @@
+import os
+import glob
+import logging
+import torch
+from argparse import ArgumentParser
+from monai.data import DataLoader, CacheDataset
+from monai.networks.nets import HoVerNet
+from monai.engines import SupervisedEvaluator
+from monai.transforms import (
+ LoadImaged,
+ Lambdad,
+ Activationsd,
+ Compose,
+ CastToTyped,
+ ComputeHoVerMapsd,
+ ScaleIntensityRanged,
+ CenterSpatialCropd,
+)
+from monai.handlers import (
+ MeanDice,
+ StatsHandler,
+ CheckpointLoader,
+)
+from monai.utils.enums import HoVerNetBranch
+from monai.apps.pathology.handlers.utils import from_engine_hovernet
+from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
+from skimage import measure
+
+
+def prepare_data(data_dir, phase):
+ data_dir = os.path.join(data_dir, phase)
+
+ images = sorted(glob.glob(os.path.join(data_dir, "*/*image.npy")))
+ inst_maps = sorted(glob.glob(os.path.join(data_dir, "*/*inst_map.npy")))
+ type_maps = sorted(glob.glob(os.path.join(data_dir, "*/*type_map.npy")))
+
+ data_dicts = [
+ {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
+ for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
+ ]
+
+ return data_dicts
+
+
+def run(cfg):
+ if cfg["mode"].lower() == "original":
+ cfg["patch_size"] = [270, 270]
+ cfg["out_size"] = [80, 80]
+ elif cfg["mode"].lower() == "fast":
+ cfg["patch_size"] = [256, 256]
+ cfg["out_size"] = [164, 164]
+
+ device = torch.device("cuda" if cfg["use_gpu"] and torch.cuda.is_available() else "cpu")
+ val_transforms = Compose(
+ [
+ LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+ Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+ CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
+ CenterSpatialCropd(
+ keys="image",
+ roi_size=cfg["patch_size"],
+ ),
+ ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
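+ # ComputeHoVerMapsd derives horizontal/vertical distance maps from the
+ # instance map (stored as "hover_label_inst"); the Lambdad below also
+ # writes a binary nucleus mask to the new key "label"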
+ ComputeHoVerMapsd(keys="label_inst"),
+ Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+ CenterSpatialCropd(
+ keys=["label", "hover_label_inst", "label_inst", "label_type"],
+ roi_size=cfg["out_size"],
+ ),
+ CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+ ]
+ )
+
+ # Create MONAI DataLoaders
+ valid_data = prepare_data(cfg["root"], "valid")
+ valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
+ val_loader = DataLoader(
+ valid_ds,
+ batch_size=cfg["batch_size"],
+ num_workers=cfg["num_workers"],
+ pin_memory=torch.cuda.is_available()
+ )
+
+ # initialize model
+ model = HoVerNet(
+ mode=cfg["mode"],
+ in_channels=3,
+ out_classes=cfg["out_classes"],
+ act=("relu", {"inplace": True}),
+ norm="batch",
+ pretrained_url=None,
+ freeze_encoder=False,
+ ).to(device)
+
+ post_process_np = Compose([
+ Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+ Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1: 2, ...] > 0.5)])
+ post_process = Lambdad(keys="pred", func=post_process_np)
+
+ # Evaluator
+ val_handlers = [
+ CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}),
+ StatsHandler(output_transform=lambda x: None),
+ ]
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=val_loader,
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ network=model,
+ postprocessing=post_process,
+ key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ val_handlers=val_handlers,
+ amp=cfg["amp"],
+ )
+
+ state = evaluator.run()
+ print(state)
+
+
+def main():
+ parser = ArgumentParser(description="Evaluate a HoVerNet model for nuclear segmentation and classification.")
+ parser.add_argument(
+ "--root",
+ type=str,
+ default="/workspace/Data/CoNSeP/Prepared/consep",
+ help="root data dir",
+ )
+
+ parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
+ parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
+ parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
+ parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")
+
+ parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
+ parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
+ parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
+
+ args = parser.parse_args()
+ cfg = vars(args)
+ print(cfg)
+
+ logging.basicConfig(level=logging.INFO)
+ run(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pathology/hovernet/prepare_patches.py b/pathology/hovernet/prepare_patches.py
new file mode 100644
index 0000000000..bf8e90f9a4
--- /dev/null
+++ b/pathology/hovernet/prepare_patches.py
@@ -0,0 +1,228 @@
+import os
+import math
+import tqdm
+import glob
+import shutil
+import pathlib
+
+import numpy as np
+import scipy.io as sio
+from PIL import Image
+from argparse import ArgumentParser
+
+
+def load_img(path):
+ return np.array(Image.open(path).convert("RGB"))
+
+
+def load_ann(path):
+ # assumes that ann is HxW
+ ann_inst = sio.loadmat(path)["inst_map"]
+ ann_type = sio.loadmat(path)["type_map"]
+
+ # merge the CoNSeP nuclei classes into 4 types as in the paper:
+ # 3/4 -> 3 (epithelial), 5/6/7 -> 4 (spindle-shaped); background unchanged
+ ann_type[(ann_type == 3) | (ann_type == 4)] = 3
+ ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
+
+ ann = np.dstack([ann_inst, ann_type])
+ ann = ann.astype("int32")
+
+ return ann
+
+
+class PatchExtractor:
+ """Extractor to generate patches with or without padding.
+ Turn on debug mode to see how it is done.
+
+ Args:
+ x : input image, should be of shape HWC
+ patch_size : a tuple of (h, w)
+ step_size : a tuple of (h, w)
+ Return:
+ a list of sub patches, each patch has dtype same as x
+
+ Examples:
+ >>> xtractor = PatchExtractor((450, 450), (120, 120))
+ >>> img = np.full([1200, 1200, 3], 255, np.uint8)
+ >>> patches = xtractor.extract(img, 'mirror')
+
+ """
+
+ def __init__(self, patch_size, step_size):
+ self.patch_type = "mirror"
+ self.patch_size = patch_size
+ self.step_size = step_size
+
+ def __get_patch(self, x, ptx):
+ pty = (ptx[0] + self.patch_size[0], ptx[1] + self.patch_size[1])
+ win = x[ptx[0] : pty[0], ptx[1] : pty[1]]
+ assert (
+ win.shape[0] == self.patch_size[0] and win.shape[1] == self.patch_size[1]
+ ), "[BUG] Incorrect Patch Size {0}".format(win.shape)
+ return win
+
+ def __extract_valid(self, x):
+ """Extract patches without padding; only works when patch_size > step_size.
+
+ Note: to handle the remaining portions at the boundary (i.e. those that do
+ not fit when sliding left->right, top->bottom), we flip the sliding
+ direction and extract one patch starting from the right / bottom edge.
+ One additional patch is extracted at the bottom-right corner.
+
+ Args:
+ x : input image, should be of shape HWC
+ patch_size : a tuple of (h, w)
+ step_size : a tuple of (h, w)
+ Return:
+ a list of sub patches, each patch is same dtype as x
+
+ """
+ im_h = x.shape[0]
+ im_w = x.shape[1]
+
+ def extract_infos(length, patch_size, step_size):
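+ # flag: whether an extra boundary patch is needed along this axis;
+ # last_step: exclusive upper bound of the regular sliding positions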
+ flag = (length - patch_size) % step_size != 0
+ last_step = math.floor((length - patch_size) / step_size)
+ last_step = (last_step + 1) * step_size
+ return flag, last_step
+
+ h_flag, h_last = extract_infos(im_h, self.patch_size[0], self.step_size[0])
+ w_flag, w_last = extract_infos(im_w, self.patch_size[1], self.step_size[1])
+
+ sub_patches = []
+ # Deal with valid block
+ for row in range(0, h_last, self.step_size[0]):
+ for col in range(0, w_last, self.step_size[1]):
+ win = self.__get_patch(x, (row, col))
+ sub_patches.append(win)
+ # Deal with edge case
+ if h_flag:
+ row = im_h - self.patch_size[0]
+ for col in range(0, w_last, self.step_size[1]):
+ win = self.__get_patch(x, (row, col))
+ sub_patches.append(win)
+ if w_flag:
+ col = im_w - self.patch_size[1]
+ for row in range(0, h_last, self.step_size[0]):
+ win = self.__get_patch(x, (row, col))
+ sub_patches.append(win)
+ if h_flag and w_flag:
+ ptx = (im_h - self.patch_size[0], im_w - self.patch_size[1])
+ win = self.__get_patch(x, ptx)
+ sub_patches.append(win)
+ return sub_patches
+
+ def __extract_mirror(self, x):
+ """Extract patches after mirror-padding the boundary, such that the
+ central region of each patch is always within the original (non-padded)
+ image while the central regions of all patches cover the whole original image.
+
+ Args:
+ x : input image, should be of shape HWC
+ patch_size : a tuple of (h, w)
+ step_size : a tuple of (h, w)
+ Return:
+ a list of sub patches, each patch is same dtype as x
+
+ """
+ diff_h = self.patch_size[0] - self.step_size[0]
+ padt = diff_h // 2
+ padb = diff_h - padt
+
+ diff_w = self.patch_size[1] - self.step_size[1]
+ padl = diff_w // 2
+ padr = diff_w - padl
+
+ pad_type = "reflect"
+ x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type)
+ sub_patches = self.__extract_valid(x)
+ return sub_patches
+
+ def extract(self, x, patch_type):
+ patch_type = patch_type.lower()
+ self.patch_type = patch_type
+ if patch_type == "valid":
+ return self.__extract_valid(x)
+ elif patch_type == "mirror":
+ return self.__extract_mirror(x)
+ else:
+ raise ValueError(f"Unknown patch type [{patch_type}]")
+
+
+def main(cfg):
+ xtractor = PatchExtractor(cfg["patch_size"], cfg["step_size"])
+
+ for phase in ["Train", "Test"]:
+ img_dir = os.path.join(cfg["root"], f"{phase}/Images")
+ ann_dir = os.path.join(cfg["root"], f"{phase}/Labels")
+
+ file_list = glob.glob(os.path.join(ann_dir, "*.mat"))
+ file_list.sort() # ensure same ordering across platform
+
+ out_dir = f"{cfg['root']}/Prepared/{phase}"
+ if os.path.isdir(out_dir):
+ shutil.rmtree(out_dir)
+ os.makedirs(out_dir)
+
+ pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
+ pbarx = tqdm.tqdm(
+ total=len(file_list), bar_format=pbar_format, ascii=True, position=0
+ )
+
+ for file_path in file_list:
+ base_name = pathlib.Path(file_path).stem
+
+ img = load_img(f"{img_dir}/{base_name}.png")
+ ann = load_ann(f"{ann_dir}/{base_name}.mat")
+
+ # stack the image with its instance and type maps so they are patched together
+ img = np.concatenate([img, ann], axis=-1)
+ sub_patches = xtractor.extract(img, cfg["extract_type"])
+
+ pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
+ pbar = tqdm.tqdm(
+ total=len(sub_patches),
+ leave=False,
+ bar_format=pbar_format,
+ ascii=True,
+ position=1,
+ )
+
+ for idx, patch in enumerate(sub_patches):
+ image_patch = patch[..., :3].transpose(2, 0, 1) # make channel first
+ inst_map_patch = patch[..., 3][None] # add channel dim
+ type_map_patch = patch[..., 4][None] # add channel dim
+ np.save("{0}/{1}_{2:03d}_image.npy".format(out_dir, base_name, idx), image_patch)
+ np.save("{0}/{1}_{2:03d}_inst_map.npy".format(out_dir, base_name, idx), inst_map_patch)
+ np.save("{0}/{1}_{2:03d}_type_map.npy".format(out_dir, base_name, idx), type_map_patch)
+ pbar.update()
+ pbar.close()
+
+ pbarx.update()
+ pbarx.close()
+
+
+def parse_arguments():
+ parser = ArgumentParser(description="Extract patches from the original images")
+
+ parser.add_argument(
+ "--root",
+ type=str,
+ default="/workspace/Data/CoNSeP",
+ help="root path to image folder containing training/test",
+ )
+ parser.add_argument("--type", type=str, default="mirror", dest="extract_type", help="Choose 'mirror' or 'valid'")
+ parser.add_argument("--ps", nargs='+', type=int, default=[540, 540], dest="patch_size", help="patch size")
+ parser.add_argument("--ss", nargs='+', type=int, default=[164, 164], dest="step_size", help="step size")
+ args = parser.parse_args()
+ config_dict = vars(args)
+
+ return config_dict
+
+
+if __name__ == "__main__":
+ cfg = parse_arguments()
+
+ main(cfg)
diff --git a/pathology/hovernet/training.py b/pathology/hovernet/training.py
new file mode 100644
index 0000000000..d402dc5e0b
--- /dev/null
+++ b/pathology/hovernet/training.py
@@ -0,0 +1,365 @@
+import os
+import glob
+import time
+import logging
+import torch
+import numpy as np
+import torch.distributed as dist
+from argparse import ArgumentParser
+from monai.data import DataLoader, partition_dataset, CacheDataset
+from monai.networks.nets import HoVerNet
+from monai.engines import SupervisedEvaluator, SupervisedTrainer
+from monai.transforms import (
+ LoadImaged,
+ TorchVisiond,
+ Lambdad,
+ Activationsd,
+ OneOf,
+ MedianSmoothd,
+ AsDiscreted,
+ Compose,
+ CastToTyped,
+ ComputeHoVerMapsd,
+ ScaleIntensityRanged,
+ RandGaussianNoised,
+ RandFlipd,
+ RandAffined,
+ RandGaussianSmoothd,
+ CenterSpatialCropd,
+)
+from monai.handlers import (
+ MeanDice,
+ CheckpointSaver,
+ LrScheduleHandler,
+ StatsHandler,
+ TensorBoardStatsHandler,
+ ValidationHandler,
+ from_engine,
+)
+from monai.utils import set_determinism
+from monai.utils.enums import HoVerNetBranch
+from monai.apps.pathology.handlers.utils import from_engine_hovernet
+from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
+from monai.apps.pathology.losses import HoVerNetLoss
+from skimage import measure
+
+
+def create_log_dir(cfg):
+ timestamp = time.strftime("%y%m%d-%H%M")
+ run_folder_name = (
+ f"{timestamp}_hovernet_bs{cfg['batch_size']}_ep{cfg['n_epochs']}_lr{cfg['lr']}_seed{cfg['seed']}_stage{cfg['stage']}"
+ )
+ log_dir = os.path.join(cfg["logdir"], run_folder_name)
+ print(f"Logs and model are saved at '{log_dir}'.")
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir, exist_ok=True)
+ return log_dir
+
+
+def prepare_data(data_dir, phase):
+ # prepare datalist
+ images = sorted(
+ glob.glob(os.path.join(data_dir, f"{phase}/*image.npy")))
+ inst_maps = sorted(
+ glob.glob(os.path.join(data_dir, f"{phase}/*inst_map.npy")))
+ type_maps = sorted(
+ glob.glob(os.path.join(data_dir, f"{phase}/*type_map.npy")))
+
+ data_dicts = [
+ {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
+ for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
+ ]
+
+ return data_dicts
+
+
+def get_loaders(cfg, train_transforms, val_transforms):
+ multi_gpu = torch.cuda.device_count() > 1
+
+ train_data = prepare_data(cfg["root"], "Train")
+ valid_data = prepare_data(cfg["root"], "Test")
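+ # under DDP, shard train/valid data so each rank sees a distinct subset;
+ # even_divisible=True keeps the shards the same size across ranks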
+ if multi_gpu:
+ train_data = partition_dataset(
+ data=train_data,
+ num_partitions=dist.get_world_size(),
+ even_divisible=True,
+ shuffle=True,
+ seed=cfg["seed"],
+ )[dist.get_rank()]
+ valid_data = partition_dataset(
+ data=valid_data,
+ num_partitions=dist.get_world_size(),
+ even_divisible=True,
+ shuffle=False,
+ seed=cfg["seed"],
+ )[dist.get_rank()]
+
+ print("train_files:", len(train_data))
+ print("val_files:", len(valid_data))
+
+ train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=1.0, num_workers=4)
+ valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
+ train_loader = DataLoader(
+ train_ds,
+ batch_size=cfg["batch_size"],
+ num_workers=cfg["num_workers"],
+ shuffle=True,
+ pin_memory=torch.cuda.is_available()
+ )
+ val_loader = DataLoader(
+ valid_ds,
+ batch_size=cfg["batch_size"],
+ num_workers=cfg["num_workers"],
+ pin_memory=torch.cuda.is_available()
+ )
+
+ return train_loader, val_loader
+
+
+def create_model(cfg, device):
+ # Each user is responsible for checking the content of models/datasets and the applicable licenses and
+ # determining if suitable for the intended use.
+ # The license for the below pre-trained model is different than MONAI license.
+ # Please check the source where these weights are obtained from:
+ # https://github.com/vqdang/hover_net#data-format
+ pretrained_model = "https://drive.google.com/u/1/uc?id=1KntZge40tAHgyXmHYVqZZ5d2p_4Qr2l5&export=download"
+ if cfg["stage"] == 0:
+ model = HoVerNet(
+ mode=cfg["mode"],
+ in_channels=3,
+ out_classes=cfg["out_classes"],
+ act=("relu", {"inplace": True}),
+ norm="batch",
+ pretrained_url=pretrained_model,
+ freeze_encoder=True,
+ ).to(device)
+ print(f'stage{cfg["stage"]} start!')
+ else:
+ model = HoVerNet(
+ mode=cfg["mode"],
+ in_channels=3,
+ out_classes=cfg["out_classes"],
+ act=("relu", {"inplace": True}),
+ norm="batch",
+ pretrained_url=None,
+ freeze_encoder=False,
+ ).to(device)
+ model.load_state_dict(torch.load(cfg["ckpt_path"], map_location=device)['net'])
+ print(f'stage{cfg["stage"]}: loaded weights from the stage-0 checkpoint!')
+
+ return model
+
+
+def run(log_dir, cfg):
+ set_determinism(seed=cfg["seed"])
+
+ if cfg["mode"].lower() == "original":
+ cfg["patch_size"] = [270, 270]
+ cfg["out_size"] = [80, 80]
+ elif cfg["mode"].lower() == "fast":
+ cfg["patch_size"] = [256, 256]
+ cfg["out_size"] = [164, 164]
+
+ multi_gpu = torch.cuda.device_count() > 1
+ if multi_gpu:
+ dist.init_process_group(backend="nccl", init_method="env://")
+ device = torch.device("cuda:{}".format(dist.get_rank()))
+ torch.cuda.set_device(device)
+ else:
+ device = torch.device("cuda" if cfg["use_gpu"] else "cpu")
+
+ # --------------------------------------------------------------------------
+ # Data Loading and Preprocessing
+ # --------------------------------------------------------------------------
+ # __________________________________________________________________________
+ # Build MONAI preprocessing
+ train_transforms = Compose(
+ [
+ LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+ Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+ RandAffined(
+ keys=["image", "label_inst", "label_type"],
+ prob=1.0,
+ rotate_range=(np.pi, 0),
+ scale_range=(0.2, 0.2),
+ shear_range=(0.05, 0.05),
+ translate_range=(6, 6),
+ padding_mode="zeros",
+ mode="nearest",
+ ),
+ CenterSpatialCropd(
+ keys="image",
+ roi_size=cfg["patch_size"],
+ ),
+ RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=0),
+ RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=1),
+ OneOf(transforms=[
+ RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0),
+ MedianSmoothd(keys=["image"], radius=1),
+ RandGaussianNoised(keys=["image"], prob=1.0, std=0.05)
+ ]),
+ CastToTyped(keys="image", dtype=np.uint8),
+ TorchVisiond(
+ keys=["image"],
+ name="ColorJitter",
+ brightness=(229 / 255.0, 281 / 255.0),
+ contrast=(0.95, 1.10),
+ saturation=(0.8, 1.2),
+ hue=(-0.04, 0.04)
+ ),
+ AsDiscreted(keys=["label_type"], to_onehot=[5]),
+ ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+ CastToTyped(keys="label_inst", dtype=torch.int),
+ ComputeHoVerMapsd(keys="label_inst"),
+ Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+ CenterSpatialCropd(
+ keys=["label", "hover_label_inst", "label_inst", "label_type"],
+ roi_size=cfg["out_size"],
+ ),
+ AsDiscreted(keys=["label"], to_onehot=2),
+ CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+ ]
+ )
+ val_transforms = Compose(
+ [
+ LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+ Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+ CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
+ CenterSpatialCropd(
+ keys="image",
+ roi_size=cfg["patch_size"],
+ ),
+ ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+ ComputeHoVerMapsd(keys="label_inst"),
+ Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+ CenterSpatialCropd(
+ keys=["label", "hover_label_inst", "label_inst", "label_type"],
+ roi_size=cfg["out_size"],
+ ),
+ CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+ ]
+ )
+
+ # __________________________________________________________________________
+ # Create MONAI DataLoaders
+ train_loader, val_loader = get_loaders(cfg, train_transforms, val_transforms)
+
+ # --------------------------------------------------------------------------
+ # Create Model, Loss, Optimizer, lr_scheduler
+ # --------------------------------------------------------------------------
+ # __________________________________________________________________________
+ # initialize model
+ model = create_model(cfg, device)
+ if multi_gpu:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
+ )
+ loss_function = HoVerNetLoss(lambda_hv_mse=1.0)
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"], weight_decay=1e-5)
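+ # decay the learning rate by StepLR's default factor of 0.1 every `--step` epochs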
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg["step_size"])
+ post_process_np = Compose([
+ Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+ AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True),
+ ])
+ post_process = Lambdad(keys="pred", func=post_process_np)
+
+ # --------------------------------------------
+ # Ignite Trainer/Evaluator
+ # --------------------------------------------
+ # Evaluator
+ val_handlers = [
+ CheckpointSaver(
+ save_dir=log_dir,
+ save_dict={"net": model},
+ save_key_metric=True,
+ ),
+ StatsHandler(output_transform=lambda x: None),
+ TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
+ ]
+ if multi_gpu:
+ val_handlers = val_handlers if dist.get_rank() == 0 else None
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=val_loader,
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ network=model,
+ postprocessing=post_process,
+ key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ val_handlers=val_handlers,
+ amp=cfg["amp"],
+ )
+
+ # Trainer
+ train_handlers = [
+ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
+ ValidationHandler(validator=evaluator, interval=cfg["val_freq"], epoch_level=True),
+ CheckpointSaver(
+ save_dir=log_dir,
+ save_dict={"net": model, "opt": optimizer},
+ save_interval=cfg["save_interval"],
+ epoch_level=True,
+ ),
+ StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
+ TensorBoardStatsHandler(
+ log_dir=log_dir, tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
+ ),
+ ]
+ if multi_gpu:
+ train_handlers = train_handlers if dist.get_rank() == 0 else train_handlers[:2]
+ trainer = SupervisedTrainer(
+ device=device,
+ max_epochs=cfg["n_epochs"],
+ train_data_loader=train_loader,
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ network=model,
+ optimizer=optimizer,
+ loss_function=loss_function,
+ postprocessing=post_process,
+ key_train_metric={"train_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ train_handlers=train_handlers,
+ amp=cfg["amp"],
+ )
+ trainer.run()
+
+ if multi_gpu:
+ dist.destroy_process_group()
+
+
+def main():
+ parser = ArgumentParser(description="Train a HoVerNet model for nuclear segmentation and classification.")
+ parser.add_argument(
+ "--root",
+ type=str,
+ default="/workspace/Data/CoNSeP/Prepared",
+ help="root data dir",
+ )
+ parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory")
+ parser.add_argument("-s", "--seed", type=int, default=24)
+
+ parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
+ parser.add_argument("--ep", type=int, default=3, dest="n_epochs", help="number of epochs")
+ parser.add_argument("--lr", type=float, default=1e-4, dest="lr", help="initial learning rate")
+ parser.add_argument("--step", type=int, default=25, dest="step_size", help="period of learning rate decay")
+ parser.add_argument("-f", "--val_freq", type=int, default=1, help="validation frequency")
+ parser.add_argument("--stage", type=int, default=0, dest="stage", help="training stage")
+ parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
+ parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
+ parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")
+
+ parser.add_argument("--save_interval", type=int, default=10)
+ parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
+ parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
+ parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
+
+ args = parser.parse_args()
+ cfg = vars(args)
+ print(cfg)
+
+ logging.basicConfig(level=logging.INFO)
+ log_dir = create_log_dir(cfg)
+ run(log_dir, cfg)
+
+
+if __name__ == "__main__":
+ main()