diff --git a/pathology/hovernet/README.MD b/pathology/hovernet/README.MD
new file mode 100644
index 0000000000..fef6268eb2
--- /dev/null
+++ b/pathology/hovernet/README.MD
@@ -0,0 +1,93 @@
+# HoVerNet Examples
+
+This folder contains Ignite-based examples to train and validate a HoVerNet model.
+It also contains PyTorch-based notebooks for training and evaluation.
+
+*Figure: hovernet scheme*
+
+Implementation based on:
+
+Simon Graham et al., "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images," Medical Image Analysis, 2019. https://arxiv.org/abs/1812.06499
+
+### 1. Data
+
+The CoNSeP dataset used in these examples can be downloaded from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/.
+- First, download the CoNSeP dataset to `data_root`.
+- Then run prepare_patches.py to extract patches from the images.
+
+### 2. Questions and bugs
+
+- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
+- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
+- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
+
+
+### 3. List of notebooks and examples
+#### [Prepare Your Data](./prepare_patches.py)
+This example prepares patches from the CoNSeP tiles, following the implementation from https://github.com/vqdang/hover_net/blob/master/extract_patches.py. Prepared patches are saved in `data_root`/Prepared.
+
+```bash
+# Run to see all possible options
+python ./prepare_patches.py -h
+
+# Prepare patches from images
+python ./prepare_patches.py \
+    --root `data_root`
+```
+
+#### [HoVerNet Training](./training.py)
+This example uses a MONAI workflow to train a HoVerNet model on the prepared CoNSeP dataset.
+HoVerNet is trained via a two-stage approach: the model is first initialised with weights pre-trained on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848), only the decoders are trained for the first 50 epochs, and then all layers are fine-tuned for another 50 epochs. The stage therefore needs to be specified with `--stage` during training.
+
+Each user is responsible for checking the content of models/datasets and the applicable licenses and determining if suitable for the intended use.
+The license for the pre-trained model used in these examples differs from the MONAI license. Please check the source where these weights are obtained from:
+https://github.com/vqdang/hover_net#data-format
+
+
+```bash
+# Run to see all possible options
+python ./training.py -h
+
+# Train a HoVerNet model on a single GPU (replace with your own ckpt path)
+export CUDA_VISIBLE_DEVICES=0; python training.py \
+    --ep 50 \
+    --stage 0 \
+    --bs 16 \
+    --root `data_root`/Prepared
+export CUDA_VISIBLE_DEVICES=0; python training.py \
+    --ep 50 \
+    --stage 1 \
+    --bs 4 \
+    --root `data_root`/Prepared \
+    --ckpt logs/stage0/checkpoint_epoch=50.pt
+
+# Train a HoVerNet model on multiple GPUs (NVIDIA) (replace with your own ckpt path)
+torchrun --nnodes=1 --nproc_per_node=2 training.py \
+    --ep 50 \
+    --bs 8 \
+    --root `data_root`/Prepared \
+    --stage 0
+torchrun --nnodes=1 --nproc_per_node=2 training.py \
+    --ep 50 \
+    --bs 2 \
+    --root `data_root`/Prepared \
+    --stage 1 \
+    --ckpt logs/stage0/checkpoint_epoch=50.pt
+```
+
+#### [HoVerNet Validation](./evaluation.py)
+This example uses a MONAI workflow to evaluate the trained HoVerNet model on the prepared test data from the CoNSeP dataset, with the metrics computed in `original` mode. We reproduce the results with Dice: 0.82762; PQ: 0.48976; F1d: 0.73592.
+```bash
+# Run to see all possible options
+python ./evaluation.py -h
+
+# Evaluate a HoVerNet model (replace with your own ckpt path)
+python ./evaluation.py \
+    --root `data_root`/Prepared \
+    --ckpt logs/stage1/checkpoint_epoch=50.pt
+```
+
+## Disclaimer
+
+This is an example, not to be used for diagnostic purposes.
diff --git a/pathology/hovernet/evaluation.py b/pathology/hovernet/evaluation.py
new file mode 100644
index 0000000000..476da843ca
--- /dev/null
+++ b/pathology/hovernet/evaluation.py
@@ -0,0 +1,150 @@
+import os
+import glob
+import logging
+import torch
+from argparse import ArgumentParser
+from monai.data import DataLoader, CacheDataset
+from monai.networks.nets import HoVerNet
+from monai.engines import SupervisedEvaluator
+from monai.transforms import (
+    LoadImaged,
+    Lambdad,
+    Activationsd,
+    Compose,
+    CastToTyped,
+    ComputeHoVerMapsd,
+    ScaleIntensityRanged,
+    CenterSpatialCropd,
+)
+from monai.handlers import (
+    MeanDice,
+    StatsHandler,
+    CheckpointLoader,
+)
+from monai.utils.enums import HoVerNetBranch
+from monai.apps.pathology.handlers.utils import from_engine_hovernet
+from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
+from skimage import measure
+
+
+def prepare_data(data_dir, phase):
+    # patches are stored flat under <data_dir>/<phase>, as produced by prepare_patches.py
+    data_dir = os.path.join(data_dir, phase)
+
+    images = sorted(glob.glob(os.path.join(data_dir, "*image.npy")))
+    inst_maps = sorted(glob.glob(os.path.join(data_dir, "*inst_map.npy")))
+    type_maps = sorted(glob.glob(os.path.join(data_dir, "*type_map.npy")))
+
+    data_dicts = [
+        {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
+        for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
+    ]
+
+    return data_dicts
+
+
+def run(cfg):
+    if cfg["mode"].lower() == "original":
+        cfg["patch_size"] = [270, 270]
+        cfg["out_size"] = [80, 80]
+    elif cfg["mode"].lower() == "fast":
+        cfg["patch_size"] = [256, 256]
+        cfg["out_size"] = [164, 164]
+
+    device = torch.device("cuda" if cfg["use_gpu"] and torch.cuda.is_available() else "cpu")
+    val_transforms = Compose(
+        [
+            LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+            Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+            CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
+            CenterSpatialCropd(
+                keys="image",
+                roi_size=cfg["patch_size"],
+            ),
+            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+            ComputeHoVerMapsd(keys="label_inst"),
+            Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+            CenterSpatialCropd(
+                keys=["label", "hover_label_inst", "label_inst", "label_type"],
+                roi_size=cfg["out_size"],
+            ),
+            CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+        ]
+    )
+
+    # Create MONAI DataLoaders
+    valid_data = prepare_data(cfg["root"], "Test")
+    valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
+    val_loader = DataLoader(
+        valid_ds,
+        batch_size=cfg["batch_size"],
+        num_workers=cfg["num_workers"],
+        pin_memory=torch.cuda.is_available()
+    )
+
+    # initialize model
+    model = HoVerNet(
+        mode=cfg["mode"],
+        in_channels=3,
+        out_classes=cfg["out_classes"],
+        act=("relu", {"inplace": True}),
+        norm="batch",
+        pretrained_url=None,
+        freeze_encoder=False,
+    ).to(device)
+
+    post_process_np = Compose([
+        Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+        Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1:2, ...] > 0.5),
+    ])
+    post_process = Lambdad(keys="pred", func=post_process_np)
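+    # A sketch of the assumed prediction layout (shapes for "original" mode):
+    # the value under "pred" is a dict of branch outputs, roughly
+    #   pred[HoVerNetBranch.NP.value]: (B, 2, 80, 80) nucleus-vs-background logits
+    #   pred[HoVerNetBranch.HV.value]: (B, 2, 80, 80) horizontal/vertical distance maps
+    #   pred[HoVerNetBranch.NC.value]: (B, out_classes, 80, 80) nucleus-type logits
+    # `post_process_np` above softmaxes the NP branch and thresholds its
+    # foreground channel at 0.5, producing the binary mask that MeanDice
+    # compares against the binarised "label" key below.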
+
+    # Evaluator
+    val_handlers = [
+        CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}),
+        StatsHandler(output_transform=lambda x: None),
+    ]
+    evaluator = SupervisedEvaluator(
+        device=device,
+        val_data_loader=val_loader,
+        prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+        network=model,
+        postprocessing=post_process,
+        key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+        val_handlers=val_handlers,
+        amp=cfg["amp"],
+    )
+
+    state = evaluator.run()
+    print(state)
+
+
+def main():
+    parser = ArgumentParser(description="Evaluate a HoVerNet model for nuclear segmentation and classification.")
+    parser.add_argument(
+        "--root",
+        type=str,
+        default="/workspace/Data/CoNSeP/Prepared",
+        help="root data dir",
+    )
+
+    parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
+    parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
+    parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
+    parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")
+
+    parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
+    parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
+    parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
+
+    args = parser.parse_args()
+    cfg = vars(args)
+    print(cfg)
+
+    logging.basicConfig(level=logging.INFO)
+    run(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pathology/hovernet/prepare_patches.py b/pathology/hovernet/prepare_patches.py
new file mode 100644
index 0000000000..bf8e90f9a4
--- /dev/null
+++ b/pathology/hovernet/prepare_patches.py
@@ -0,0 +1,228 @@
+import os
+import math
+import tqdm
+import glob
+import shutil
+import pathlib
+
+import numpy as np
+import scipy.io as sio
+from PIL import Image
+from argparse import ArgumentParser
+
+
+def load_img(path):
+    return np.array(Image.open(path).convert("RGB"))
+
+
+def load_ann(path):
+    # assumes that ann is HxW
+    ann_inst = sio.loadmat(path)["inst_map"]
+    ann_type = sio.loadmat(path)["type_map"]
+
+    # merge classes for CoNSeP, as in the paper: classes 3 & 4 are merged and
+    # classes 5, 6 & 7 are merged, giving 4 nuclei types; background stays 0
+    ann_type[(ann_type == 3) | (ann_type == 4)] = 3
+    ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
+
+    ann = np.dstack([ann_inst, ann_type])
+    ann = ann.astype("int32")
+
+    return ann
+
+
+class PatchExtractor:
+    """Extractor to generate patches with or without padding.
+
+    Args:
+        x : input image, should be of shape HWC
+        patch_size : a tuple of (h, w)
+        step_size : a tuple of (h, w)
+    Return:
+        a list of sub patches, each patch has dtype same as x
+
+    Examples:
+        >>> xtractor = PatchExtractor((450, 450), (120, 120))
+        >>> img = np.full([1200, 1200, 3], 255, np.uint8)
+        >>> patches = xtractor.extract(img, 'mirror')
+
+    """
+
+    def __init__(self, patch_size, step_size):
+        self.patch_type = "mirror"
+        self.patch_size = patch_size
+        self.step_size = step_size
+
+    def __get_patch(self, x, ptx):
+        pty = (ptx[0] + self.patch_size[0], ptx[1] + self.patch_size[1])
+        win = x[ptx[0] : pty[0], ptx[1] : pty[1]]
+        assert (
+            win.shape[0] == self.patch_size[0] and win.shape[1] == self.patch_size[1]
+        ), "[BUG] Incorrect Patch Size {0}".format(win.shape)
+        return win
+
+    def __extract_valid(self, x):
+        """Extract patches without padding; only works when patch_size > step_size.
+
+        Note: to deal with the remaining portions at the boundary (i.e. those
+        which do not fit when sliding left->right, top->bottom), we flip the
+        sliding direction and extract one patch starting from the right / bottom
+        edge. One additional patch is extracted at the bottom-right corner.
+
+        Args:
+            x : input image, should be of shape HWC
+        Return:
+            a list of sub patches, each patch is same dtype as x
+
+        """
+        im_h = x.shape[0]
+        im_w = x.shape[1]
+
+        def extract_infos(length, patch_size, step_size):
+            # flag: whether an extra boundary patch is needed; last_step: the
+            # (exclusive) upper bound for the regular sliding positions
+            flag = (length - patch_size) % step_size != 0
+            last_step = math.floor((length - patch_size) / step_size)
+            last_step = (last_step + 1) * step_size
+            return flag, last_step
+
+        h_flag, h_last = extract_infos(im_h, self.patch_size[0], self.step_size[0])
+        w_flag, w_last = extract_infos(im_w, self.patch_size[1], self.step_size[1])
+
+        sub_patches = []
+        # Deal with valid block
+        for row in range(0, h_last, self.step_size[0]):
+            for col in range(0, w_last, self.step_size[1]):
+                win = self.__get_patch(x, (row, col))
+                sub_patches.append(win)
+        # Deal with edge case
+        if h_flag:
+            row = im_h - self.patch_size[0]
+            for col in range(0, w_last, self.step_size[1]):
+                win = self.__get_patch(x, (row, col))
+                sub_patches.append(win)
+        if w_flag:
+            col = im_w - self.patch_size[1]
+            for row in range(0, h_last, self.step_size[0]):
+                win = self.__get_patch(x, (row, col))
+                sub_patches.append(win)
+        if h_flag and w_flag:
+            ptx = (im_h - self.patch_size[0], im_w - self.patch_size[1])
+            win = self.__get_patch(x, ptx)
+            sub_patches.append(win)
+        return sub_patches
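+
+    # Worked example for __extract_valid (illustrative, assuming a 1000x1000
+    # CoNSeP tile in "valid" mode with the default patch_size=540 and
+    # step_size=164): (1000 - 540) % 164 = 132 != 0, so flag is True and
+    # last_step = (floor(460 / 164) + 1) * 164 = 492. Rows/cols therefore
+    # start at 0, 164 and 328, plus one extra patch anchored at
+    # 1000 - 540 = 460 for the border, giving a 4 x 4 = 16 patch grid.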
+
+    def __extract_mirror(self, x):
+        """Extract patches with the boundary mirror-padded, such that the
+        central region of each patch always lies within the original
+        (non-padded) image while the patches' central regions jointly cover
+        the whole original image.
+
+        Args:
+            x : input image, should be of shape HWC
+        Return:
+            a list of sub patches, each patch is same dtype as x
+
+        """
+        diff_h = self.patch_size[0] - self.step_size[0]
+        padt = diff_h // 2
+        padb = diff_h - padt
+
+        diff_w = self.patch_size[1] - self.step_size[1]
+        padl = diff_w // 2
+        padr = diff_w - padl
+
+        pad_type = "reflect"
+        x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type)
+        sub_patches = self.__extract_valid(x)
+        return sub_patches
+
+    def extract(self, x, patch_type):
+        patch_type = patch_type.lower()
+        self.patch_type = patch_type
+        if patch_type == "valid":
+            return self.__extract_valid(x)
+        elif patch_type == "mirror":
+            return self.__extract_mirror(x)
+        else:
+            raise ValueError("Unknown patch type [%s]" % patch_type)
+
+
+def main(cfg):
+    xtractor = PatchExtractor(cfg["patch_size"], cfg["step_size"])
+
+    for phase in ["Train", "Test"]:
+        img_dir = os.path.join(cfg["root"], f"{phase}/Images")
+        ann_dir = os.path.join(cfg["root"], f"{phase}/Labels")
+
+        file_list = glob.glob(os.path.join(ann_dir, "*.mat"))
+        file_list.sort()  # ensure same ordering across platforms
+
+        out_dir = f"{cfg['root']}/Prepared/{phase}"
+        if os.path.isdir(out_dir):
+            shutil.rmtree(out_dir)
+        os.makedirs(out_dir)
+
+        pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
+        pbarx = tqdm.tqdm(
+            total=len(file_list), bar_format=pbar_format, ascii=True, position=0
+        )
+
+        for file_path in file_list:
+            base_name = pathlib.Path(file_path).stem
+
+            img = load_img(f"{img_dir}/{base_name}.png")
+            ann = load_ann(f"{ann_dir}/{base_name}.mat")
+
+            # stack the annotation maps onto the image so that image and labels
+            # are cropped into patches together
+            img = np.concatenate([img, ann], axis=-1)
+            sub_patches = xtractor.extract(img, cfg["extract_type"])
+
+            pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
+            pbar = tqdm.tqdm(
+                total=len(sub_patches),
+                leave=False,
+                bar_format=pbar_format,
+                ascii=True,
+                position=1,
+            )
+
+            for idx, patch in enumerate(sub_patches):
+                image_patch = patch[..., :3].transpose(2, 0, 1)  # make channel first
+                inst_map_patch = patch[..., 3][None]  # add channel dim
+                type_map_patch = patch[..., 4][None]  # add channel dim
+                np.save("{0}/{1}_{2:03d}_image.npy".format(out_dir, base_name, idx), image_patch)
+                np.save("{0}/{1}_{2:03d}_inst_map.npy".format(out_dir, base_name, idx), inst_map_patch)
+                np.save("{0}/{1}_{2:03d}_type_map.npy".format(out_dir, base_name, idx), type_map_patch)
+                pbar.update()
+            pbar.close()
+
+            pbarx.update()
+        pbarx.close()
+
+
+def parse_arguments():
+    parser = ArgumentParser(description="Extract patches from the original images")
+
+    parser.add_argument(
+        "--root",
+        type=str,
+        default="/workspace/Data/CoNSeP",
+        help="root path to the folder containing the Train/Test images and labels",
+    )
+    parser.add_argument("--type", type=str, default="mirror", dest="extract_type", help="choose 'mirror' or 'valid'")
+    parser.add_argument("--ps", nargs='+', type=int, default=[540, 540], dest="patch_size", help="patch size")
+    parser.add_argument("--ss", nargs='+', type=int, default=[164, 164], dest="step_size", help="step size")
+    args = parser.parse_args()
+    config_dict = vars(args)
+
+    return config_dict
+
+
+if __name__ == "__main__":
+    cfg = parse_arguments()
+
+    main(cfg)
diff --git a/pathology/hovernet/training.py b/pathology/hovernet/training.py
new file mode 100644
index 0000000000..d402dc5e0b
--- /dev/null
+++ b/pathology/hovernet/training.py
@@ -0,0 +1,365 @@
+import os
+import glob
+import time
+import logging
+import torch
+import numpy as np
+import torch.distributed as dist
+from argparse import ArgumentParser
+from monai.data import DataLoader, partition_dataset, CacheDataset
+from monai.networks.nets import HoVerNet
+from monai.engines import SupervisedEvaluator, SupervisedTrainer
+from monai.transforms import (
+    LoadImaged,
+    TorchVisiond,
+    Lambdad,
+    Activationsd,
+    OneOf,
+    MedianSmoothd,
+    AsDiscreted,
+    Compose,
+    CastToTyped,
+    ComputeHoVerMapsd,
+    ScaleIntensityRanged,
+    RandGaussianNoised,
+    RandFlipd,
+    RandAffined,
+    RandGaussianSmoothd,
+    CenterSpatialCropd,
+)
+from monai.handlers import (
+    MeanDice,
+    CheckpointSaver,
+    LrScheduleHandler,
+    StatsHandler,
+    TensorBoardStatsHandler,
+    ValidationHandler,
+    from_engine,
+)
+from monai.utils import set_determinism
+from monai.utils.enums import HoVerNetBranch
+from monai.apps.pathology.handlers.utils import from_engine_hovernet
+from monai.apps.pathology.engines.utils import PrepareBatchHoVerNet
+from monai.apps.pathology.losses import HoVerNetLoss
+from skimage import measure
+
+
+def create_log_dir(cfg):
+    timestamp = time.strftime("%y%m%d-%H%M")
+    run_folder_name = (
+        f"{timestamp}_hovernet_bs{cfg['batch_size']}_ep{cfg['n_epochs']}_lr{cfg['lr']}_seed{cfg['seed']}_stage{cfg['stage']}"
+    )
+    log_dir = os.path.join(cfg["logdir"], run_folder_name)
+    print(f"Logs and model are saved at '{log_dir}'.")
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir, exist_ok=True)
+    return log_dir
+
+
+def prepare_data(data_dir, phase):
+    # prepare datalist
+    images = sorted(
+        glob.glob(os.path.join(data_dir, f"{phase}/*image.npy")))
+    inst_maps = sorted(
+        glob.glob(os.path.join(data_dir, f"{phase}/*inst_map.npy")))
+    type_maps = sorted(
+        glob.glob(os.path.join(data_dir, f"{phase}/*type_map.npy")))
+
+    data_dicts = [
+        {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
+        for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
+    ]
+
+    return data_dicts
+
+
+def get_loaders(cfg, train_transforms, val_transforms):
+    multi_gpu = torch.cuda.device_count() > 1
+
+    train_data = prepare_data(cfg["root"], "Train")
+    valid_data = prepare_data(cfg["root"], "Test")
+    if multi_gpu:
+        train_data = partition_dataset(
+            data=train_data,
+            num_partitions=dist.get_world_size(),
+            even_divisible=True,
+            shuffle=True,
+            seed=cfg["seed"],
+        )[dist.get_rank()]
+        valid_data = partition_dataset(
+            data=valid_data,
+            num_partitions=dist.get_world_size(),
+            even_divisible=True,
+            shuffle=False,
+            seed=cfg["seed"],
+        )[dist.get_rank()]
+
+    print("train_files:", len(train_data))
+    print("val_files:", len(valid_data))
+
+    train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=1.0, num_workers=4)
+    valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
+    train_loader = DataLoader(
+        train_ds,
+        batch_size=cfg["batch_size"],
+        num_workers=cfg["num_workers"],
+        shuffle=True,
+        pin_memory=torch.cuda.is_available()
+    )
+    val_loader = DataLoader(
+        valid_ds,
+        batch_size=cfg["batch_size"],
+        num_workers=cfg["num_workers"],
+        pin_memory=torch.cuda.is_available()
+    )
+
+    return train_loader, val_loader
+
+
+def create_model(cfg, device):
+    # Each user is responsible for checking the content of models/datasets and the applicable licenses and
+    # determining if suitable for the intended use.
+    # The license for the below pre-trained model is different from the MONAI license.
+    # Please check the source where these weights are obtained from:
+    # https://github.com/vqdang/hover_net#data-format
+    pretrained_model = "https://drive.google.com/u/1/uc?id=1KntZge40tAHgyXmHYVqZZ5d2p_4Qr2l5&export=download"
+    if cfg["stage"] == 0:
+        model = HoVerNet(
+            mode=cfg["mode"],
+            in_channels=3,
+            out_classes=cfg["out_classes"],
+            act=("relu", {"inplace": True}),
+            norm="batch",
+            pretrained_url=pretrained_model,
+            freeze_encoder=True,
+        ).to(device)
+        print(f'stage {cfg["stage"]} started: training decoders with a frozen, pre-trained encoder')
+    else:
+        model = HoVerNet(
+            mode=cfg["mode"],
+            in_channels=3,
+            out_classes=cfg["out_classes"],
+            act=("relu", {"inplace": True}),
+            norm="batch",
+            pretrained_url=None,
+            freeze_encoder=False,
+        ).to(device)
+        model.load_state_dict(torch.load(cfg["ckpt_path"])['net'])
+        print(f'stage {cfg["stage"]}: successfully loaded the stage-0 weights')
+
+    return model
+
+
+def run(log_dir, cfg):
+    set_determinism(seed=cfg["seed"])
+
+    if cfg["mode"].lower() == "original":
+        cfg["patch_size"] = [270, 270]
+        cfg["out_size"] = [80, 80]
+    elif cfg["mode"].lower() == "fast":
+        cfg["patch_size"] = [256, 256]
+        cfg["out_size"] = [164, 164]
+
+    multi_gpu = torch.cuda.device_count() > 1
+    if multi_gpu:
+        dist.init_process_group(backend="nccl", init_method="env://")
+        device = torch.device("cuda:{}".format(dist.get_rank()))
+        torch.cuda.set_device(device)
+    else:
+        device = torch.device("cuda" if cfg["use_gpu"] else "cpu")
+
+    # --------------------------------------------------------------------------
+    # Data Loading and Preprocessing
+    # --------------------------------------------------------------------------
+    # __________________________________________________________________________
+    # Build MONAI preprocessing
+    train_transforms = Compose(
+        [
+            LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+            Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+            RandAffined(
+                keys=["image", "label_inst", "label_type"],
+                prob=1.0,
+                rotate_range=(np.pi, 0),
+                scale_range=(0.2, 0.2),
+                shear_range=(0.05, 0.05),
+                translate_range=(6, 6),
+                padding_mode="zeros",
+                mode="nearest",
+            ),
+            CenterSpatialCropd(
+                keys="image",
+                roi_size=cfg["patch_size"],
+            ),
+            RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=0),
+            RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=1),
+            OneOf(transforms=[
+                RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0),
+                MedianSmoothd(keys=["image"], radius=1),
+                RandGaussianNoised(keys=["image"], prob=1.0, std=0.05)
+            ]),
+            CastToTyped(keys="image", dtype=np.uint8),
+            TorchVisiond(
+                keys=["image"],
+                name="ColorJitter",
+                brightness=(229 / 255.0, 281 / 255.0),
+                contrast=(0.95, 1.10),
+                saturation=(0.8, 1.2),
+                hue=(-0.04, 0.04)
+            ),
+            AsDiscreted(keys=["label_type"], to_onehot=[5]),
+            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+            CastToTyped(keys="label_inst", dtype=torch.int),
+            ComputeHoVerMapsd(keys="label_inst"),
+            Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+            CenterSpatialCropd(
+                keys=["label", "hover_label_inst", "label_inst", "label_type"],
+                roi_size=cfg["out_size"],
+            ),
+            AsDiscreted(keys=["label"], to_onehot=2),
+            CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+        ]
+    )
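+
+    # Note on the keys produced above: ComputeHoVerMapsd reads "label_inst" and
+    # writes the horizontal/vertical distance maps under "hover_label_inst",
+    # while the Lambdad with overwrite="label" binarises "label_inst" into a
+    # separate "label" key; "label", "label_type" and "hover_label_inst" are
+    # then the per-branch targets that PrepareBatchHoVerNet hands to the loss.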
+    val_transforms = Compose(
+        [
+            LoadImaged(keys=["image", "label_inst", "label_type"], image_only=True),
+            Lambdad(keys="label_inst", func=lambda x: measure.label(x)),
+            CastToTyped(keys=["image", "label_inst"], dtype=torch.int),
+            CenterSpatialCropd(
+                keys="image",
+                roi_size=cfg["patch_size"],
+            ),
+            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+            ComputeHoVerMapsd(keys="label_inst"),
+            Lambdad(keys="label_inst", func=lambda x: x > 0, overwrite="label"),
+            CenterSpatialCropd(
+                keys=["label", "hover_label_inst", "label_inst", "label_type"],
+                roi_size=cfg["out_size"],
+            ),
+            CastToTyped(keys=["image", "label_inst", "label_type"], dtype=torch.float32),
+        ]
+    )
+
+    # __________________________________________________________________________
+    # Create MONAI DataLoaders
+    train_loader, val_loader = get_loaders(cfg, train_transforms, val_transforms)
+
+    # --------------------------------------------------------------------------
+    # Create Model, Loss, Optimizer, lr_scheduler
+    # --------------------------------------------------------------------------
+    # __________________________________________________________________________
+    # initialize model
+    model = create_model(cfg, device)
+    if multi_gpu:
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
+        )
+    loss_function = HoVerNetLoss(lambda_hv_mse=1.0)
+    # in stage 0 the encoder is frozen, so filtering on requires_grad restricts
+    # the optimizer to the decoder parameters
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"], weight_decay=1e-5)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg["step_size"])
+    post_process_np = Compose([
+        Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+        AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True),
+    ])
+    post_process = Lambdad(keys="pred", func=post_process_np)
+
+    # --------------------------------------------
+    # Ignite Trainer/Evaluator
+    # --------------------------------------------
+    # Evaluator
+    val_handlers = [
+        CheckpointSaver(
+            save_dir=log_dir,
+            save_dict={"net": model},
+            save_key_metric=True,
+        ),
+        StatsHandler(output_transform=lambda x: None),
+        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
+    ]
+    if multi_gpu:
+        # only rank 0 logs and saves checkpoints
+        val_handlers = val_handlers if dist.get_rank() == 0 else None
+    evaluator = SupervisedEvaluator(
+        device=device,
+        val_data_loader=val_loader,
+        prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+        network=model,
+        postprocessing=post_process,
+        key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+        val_handlers=val_handlers,
+        amp=cfg["amp"],
+    )
+
+    # Trainer
+    train_handlers = [
+        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
+        ValidationHandler(validator=evaluator, interval=cfg["val_freq"], epoch_level=True),
+        CheckpointSaver(
+            save_dir=log_dir,
+            save_dict={"net": model, "opt": optimizer},
+            save_interval=cfg["save_interval"],
+            epoch_level=True,
+        ),
+        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
+        TensorBoardStatsHandler(
+            log_dir=log_dir, tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
+        ),
+    ]
+    if multi_gpu:
+        # non-zero ranks keep only the LR scheduling and validation handlers so
+        # they stay in step with rank 0, which also logs and saves checkpoints
+        train_handlers = train_handlers if dist.get_rank() == 0 else train_handlers[:2]
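+    # How the pieces fit together: ValidationHandler runs `evaluator` every
+    # cfg["val_freq"] epochs inside the training loop; the evaluator's
+    # CheckpointSaver(save_key_metric=True) keeps the best model by "val_dice",
+    # while the trainer's CheckpointSaver above also snapshots the network and
+    # optimizer every cfg["save_interval"] epochs into the same log_dir.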
+    trainer = SupervisedTrainer(
+        device=device,
+        max_epochs=cfg["n_epochs"],
+        train_data_loader=train_loader,
+        prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+        network=model,
+        optimizer=optimizer,
+        loss_function=loss_function,
+        postprocessing=post_process,
+        key_train_metric={"train_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+        train_handlers=train_handlers,
+        amp=cfg["amp"],
+    )
+    trainer.run()
+
+    if multi_gpu:
+        dist.destroy_process_group()
+
+
+def main():
+    parser = ArgumentParser(description="Train a HoVerNet model for nuclear segmentation and classification.")
+    parser.add_argument(
+        "--root",
+        type=str,
+        default="/workspace/Data/CoNSeP/Prepared",
+        help="root data dir",
+    )
+    parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory")
+    parser.add_argument("-s", "--seed", type=int, default=24)
+
+    parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
+    parser.add_argument("--ep", type=int, default=3, dest="n_epochs", help="number of epochs")
+    parser.add_argument("--lr", type=float, default=1e-4, dest="lr", help="initial learning rate")
+    parser.add_argument("--step", type=int, default=25, dest="step_size", help="period of learning rate decay")
+    parser.add_argument("-f", "--val_freq", type=int, default=1, help="validation frequency")
+    parser.add_argument("--stage", type=int, default=0, dest="stage", help="training stage")
+    parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
+    parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
+    parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")
+
+    parser.add_argument("--save_interval", type=int, default=10, help="checkpoint saving interval (epochs)")
+    parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
+    parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
+    parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
+
+    args = parser.parse_args()
+    cfg = vars(args)
+    print(cfg)
+
+    logging.basicConfig(level=logging.INFO)
+    log_dir = create_log_dir(cfg)
+    run(log_dir, cfg)
+
+
+if __name__ == "__main__":
+    main()