diff --git a/pathology/hovernet/README.MD b/pathology/hovernet/README.MD
index fef6268eb2..b5bdc4ac58 100644
--- a/pathology/hovernet/README.MD
+++ b/pathology/hovernet/README.MD
@@ -3,17 +3,18 @@
 This folder contains ignite version examples to train and validate a HoVerNet model. It also has torch version notebooks to run training and evaluation.

-hovernet scheme implementation based on:
-Simon Graham et al., HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images.' Medical Image Analysis, (2019). https://arxiv.org/abs/1812.06499
+Simon Graham et al., "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images." Medical Image Analysis (2019). <https://arxiv.org/abs/1812.06499>

 ### 1. Data

-CoNSeP datasets which are used in the examples can be downloaded from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/.
-- First download CoNSeP dataset to `data_root`.
-- Run prepare_patches.py to prepare patches from images.
+The CoNSeP dataset used in these examples can be downloaded from <https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/>.
+
+- First, download the CoNSeP dataset to `DATA_ROOT` (default is `"/workspace/Data/Pathology/CoNSeP"`).
+- Run `python prepare_patches.py` to prepare patches from the images.

 ### 2. Questions and bugs
@@ -21,71 +22,96 @@ CoNSeP datasets which are used in the examples can be downloaded from https://wa
 - For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
 - For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).

 ### 3. List of notebooks and examples
+
 #### [Prepare Your Data](./prepare_patches.py)
-This example is used to prepare patches from tiles referring to the implementation from https://github.com/vqdang/hover_net/blob/master/extract_patches.py. Prepared patches will be saved in `data_root`/Prepared.
+
+This example prepares patches from image tiles, following the implementation at <https://github.com/vqdang/hover_net/blob/master/extract_patches.py>. Prepared patches are saved in `DATA_ROOT/Prepared`.

 ```bash
-# Run to know all possible options
+# Run to get all possible arguments
 python ./prepare_patches.py -h

-# Prepare patches from images
+# Prepare patches from images using default arguments
+python ./prepare_patches.py
+
+# Prepare patches using custom arguments
 python ./prepare_patches.py \
-    --root `data_root`
+    --root $DATA_ROOT \
+    --ps 540 540 \
+    --ss 164 164
 ```

 #### [HoVerNet Training](./training.py)
+
 This example uses a MONAI workflow to train a HoVerNet model on the prepared CoNSeP dataset.
-Since HoVerNet is training via a two-stage approach. First initialised the model with pre-trained weights on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848), trained only the decoders for the first 50 epochs, and then fine-tuned all layers for another 50 epochs. We need to specify `--stage` during training.
+HoVerNet is trained via a two-stage approach: the model is first initialized with weights pre-trained on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848) and only the decoders are trained for the first 50 epochs; all layers are then fine-tuned for another 50 epochs. The stage must be specified with `--stage` during training; a sketch of the two stages is shown below.

 Each user is responsible for checking the content of models/datasets and the applicable licenses and determining whether they are suitable for the intended use.
 The license for the pre-trained model used in the examples is different from the MONAI license. Please check the source where these weights are obtained from:
-https://github.com/vqdang/hover_net#data-format
+<https://github.com/vqdang/hover_net#data-format>
+
+If you didn't use the default `DATA_ROOT` during data preparation, pass `--root $DATA_ROOT/Prepared` to each of the training commands.
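+
+The two stages amount to roughly the following (a minimal sketch: the actual logic lives in `create_model` in [training.py](./training.py); the `create_model_sketch` name, the stage-0 `pretrained_url`, and `freeze_encoder=True` for stage 0 are assumptions for illustration):
+
+```python
+import torch
+from monai.networks.nets import HoVerNet
+
+def create_model_sketch(stage, device, pretrained_url=None):
+    if stage == 0:
+        # stage 0: encoder starts from ImageNet pre-trained weights (assumed passed
+        # via `pretrained_url`) and is frozen, so only the decoders are trained
+        return HoVerNet(mode="original", in_channels=3, out_classes=5,
+                        pretrained_url=pretrained_url, freeze_encoder=True).to(device)
+    # stage 1: build the network without pretrained weights, load the stage-0
+    # checkpoint, and fine-tune all layers
+    model = HoVerNet(mode="original", in_channels=3, out_classes=5,
+                     pretrained_url=None, freeze_encoder=False).to(device)
+    model.load_state_dict(torch.load("./logs/stage0/model.pt")["net"])
+    return model
+```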
 ```bash
-# Run to know all possible options
+# Run to get all possible arguments
 python ./training.py -h

-# Train a hovernet model on single-gpu(replace with your own ckpt path)
+# Train a HoVerNet model on single-GPU or CPU-only (replace with your own ckpt path)
 export CUDA_VISIBLE_DEVICES=0; python training.py \
-    --ep 50 \
     --stage 0 \
+    --ep 50 \
     --bs 16 \
-    --root `save_root`
+    --log-dir ./logs
 export CUDA_VISIBLE_DEVICES=0; python training.py \
-    --ep 50 \
     --stage 1 \
-    --bs 4 \
-    --root `save_root` \
-    --ckpt logs/stage0/checkpoint_epoch=50.pt
-
-# Train a hovernet model on multi-gpu (NVIDIA)(replace with your own ckpt path)
-torchrun --nnodes=1 --nproc_per_node=2 training.py \
-    --ep 50 \
-    --bs 8 \
-    --root `save_root` \
-    --stage 0
-torchrun --nnodes=1 --nproc_per_node=2 training.py \
-    --ep 50 \
-    --bs 2 \
-    --root `save_root` \
-    --stage 1 \
-    --ckpt logs/stage0/checkpoint_epoch=50.pt
+    --ep 50 \
+    --bs 16 \
+    --log-dir ./logs \
+    --ckpt logs/stage0/model.pt
+
+# Train a HoVerNet model on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 training.py
+torchrun --nnodes=1 --nproc_per_node=2 training.py --stage 1
 ```

 #### [HoVerNet Validation](./evaluation.py)
+
 This example uses a MONAI workflow to evaluate the trained HoVerNet model on the prepared test data from the CoNSeP dataset.
 Using the original model mode, we reproduce the reported metrics: Dice 0.82762, PQ 0.48976, F1d 0.73592.

 ```bash
-# Run to know all possible options
+# Run to get all possible arguments
 python ./evaluation.py -h

-# Evaluate a HoVerNet model
-python ./evaluation.py
+# Evaluate a HoVerNet model on single-GPU or CPU-only
+python ./evaluation.py \
     --root $DATA_ROOT/Prepared \
-    --ckpt logs/stage0/checkpoint_epoch=50.pt
+    --ckpt ./logs/model.pt
+
+# Evaluate a HoVerNet model on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 evaluation.py
+```
+
+#### [HoVerNet Inference](./inference.py)
+
+This example uses a MONAI workflow to run inference with a HoVerNet model on regions of interest of arbitrary size.
+Under the hood, it uses a sliding-window approach to run inference on overlapping patches, stitches the results
+together into an output image of the same size as the input, and then runs post-processing on that image to
+produce the final results. This example saves the instance map and type map as PNG files, but it can be modified
+to save any output of interest.
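+
+A minimal sketch of how the script configures the sliding-window inferer (values assume the `original` mode, where a 270x270 input patch yields an 80x80 output; `fast` mode uses 256 and 164):
+
+```python
+from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
+
+patch_size, out_size = 270, 80  # `original` mode
+inferer = SlidingWindowHoVerNetInferer(
+    roi_size=patch_size,
+    sw_batch_size=8,
+    # neighboring windows overlap so their valid output centers tile the image:
+    # overlap = 1 - 80/270 ~ 0.704
+    overlap=1.0 - out_size / patch_size,
+    padding_mode="constant",
+    cval=0,
+    # pad the input by (270 - 80) // 2 = 95 pixels per side so borders get valid outputs
+    extra_input_padding=((patch_size - out_size) // 2,) * 4,
+)
+```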
+
+```bash
+# Run to get all possible arguments
+python ./inference.py -h
+
+# Run HoVerNet inference on single-GPU or CPU-only
+python ./inference.py \
+    --root $DATA_ROOT/Test/Images \
+    --ckpt ./logs/model.pt
+
+# Run HoVerNet inference on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 ./inference.py
+```

 ## Disclaimer
diff --git a/pathology/hovernet/evaluation.py b/pathology/hovernet/evaluation.py
index 476da843ca..723c4e45d8 100644
--- a/pathology/hovernet/evaluation.py
+++ b/pathology/hovernet/evaluation.py
@@ -28,21 +28,18 @@


 def prepare_data(data_dir, phase):
-    data_dir = os.path.join(data_dir, phase)
+    """prepare data list"""

-    images = list(sorted(
-        glob.glob(os.path.join(data_dir, "*/*image.npy"))))
-    inst_maps = list(sorted(
-        glob.glob(os.path.join(data_dir, "*/*inst_map.npy"))))
-    type_maps = list(sorted(
-        glob.glob(os.path.join(data_dir, "*/*type_map.npy"))))
+    data_dir = os.path.join(data_dir, phase)
+    images = sorted(glob.glob(os.path.join(data_dir, "*image.npy")))
+    inst_maps = sorted(glob.glob(os.path.join(data_dir, "*inst_map.npy")))
+    type_maps = sorted(glob.glob(os.path.join(data_dir, "*type_map.npy")))

-    data_dicts = [
+    data_list = [
         {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
         for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
     ]
-
-    return data_dicts
+    return data_list


 def run(cfg):
@@ -75,13 +72,10 @@ def run(cfg):
     )

     # Create MONAI DataLoaders
-    valid_data = prepare_data(cfg["root"], "valid")
+    valid_data = prepare_data(cfg["root"], "Test")
     valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
     val_loader = DataLoader(
-        valid_ds,
-        batch_size=cfg["batch_size"],
-        num_workers=cfg["num_workers"],
-        pin_memory=torch.cuda.is_available()
+        valid_ds, batch_size=cfg["batch_size"], num_workers=cfg["num_workers"], pin_memory=torch.cuda.is_available()
     )

     # initialize model
@@ -95,23 +89,31 @@ def run(cfg):
         freeze_encoder=False,
     ).to(device)

-    post_process_np = Compose([
-        Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
-        Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1: 2, ...] > 0.5)])
+    post_process_np = Compose(
+        [
+            Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
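+            # the NP branch softmax has two channels: 0 = background, 1 = foreground;
+            # keeping channel 1 and thresholding at 0.5 gives a binary nucleus mask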
+            Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1:2, ...] > 0.5),
+        ]
+    )
     post_process = Lambdad(keys="pred", func=post_process_np)

     # Evaluator
     val_handlers = [
-        CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}),
+        CheckpointLoader(load_path=cfg["ckpt"], load_dict={"net": model}),
         StatsHandler(output_transform=lambda x: None),
     ]
     evaluator = SupervisedEvaluator(
         device=device,
         val_data_loader=val_loader,
-        prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+        prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
         network=model,
         postprocessing=post_process,
-        key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+        key_val_metric={
+            "val_dice": MeanDice(
+                include_background=False,
+                output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
+            )
+        },
         val_handlers=val_handlers,
         amp=cfg["amp"],
     )
@@ -125,10 +127,15 @@ def main():
     parser.add_argument(
         "--root",
         type=str,
-        default="/workspace/Data/CoNSeP/Prepared/consep",
+        default="/workspace/Data/Pathology/CoNSeP/Prepared",
         help="root data dir",
     )
-
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="./logs/model.pt",
+        help="Path to the pytorch checkpoint",
+    )
     parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
     parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
     parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
@@ -136,7 +143,6 @@ def main():
     parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
     parser.add_argument("--use_gpu", type=bool, default=True, dest="use_gpu", help="whether to use gpu")

-    parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
     args = parser.parse_args()

     cfg = vars(args)
diff --git a/pathology/hovernet/inference.py b/pathology/hovernet/inference.py
new file mode 100644
index 0000000000..5af7ef9d82
--- /dev/null
+++ b/pathology/hovernet/inference.py
@@ -0,0 +1,213 @@
+import logging
+import os
+import time
+from argparse import ArgumentParser
+from glob import glob
+
+import torch
+import torch.distributed as dist
+
+from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
+from monai.apps.pathology.transforms import (
+    HoVerNetInstanceMapPostProcessingd,
+    HoVerNetNuclearTypePostProcessingd,
+)
+from monai.data import DataLoader, Dataset, PILReader, partition_dataset
+from monai.engines import SupervisedEvaluator
+from monai.networks.nets import HoVerNet
+from monai.transforms import (
+    CastToTyped,
+    Compose,
+    EnsureChannelFirstd,
+    FromMetaTensord,
+    LoadImaged,
+    FlattenSubKeysd,
+    SaveImaged,
+    ScaleIntensityRanged,
+)
+from monai.utils import HoVerNetBranch, first
+
+
+def create_output_dir(cfg):
+    output_dir = cfg["output"]
+    print(f"Outputs are saved at '{output_dir}'.")
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    return output_dir
+
+
+def run(cfg):
+    # --------------------------------------------------------------------------
+    # Set Directory, Device
+    # --------------------------------------------------------------------------
+    output_dir = create_output_dir(cfg)
+    multi_gpu = cfg["use_gpu"] if torch.cuda.device_count() > 1 else False
+    if multi_gpu:
+        dist.init_process_group(backend="nccl", init_method="env://")
+        device = torch.device("cuda:{}".format(dist.get_rank()))
+        torch.cuda.set_device(device)
+        if dist.get_rank() == 0:
+            print(f"Running multi-gpu with {dist.get_world_size()} GPUs")
+    else:
+        device = torch.device("cuda" if cfg["use_gpu"] and torch.cuda.is_available() else "cpu")
+    # --------------------------------------------------------------------------
+    # Transforms
+    # --------------------------------------------------------------------------
+    # Preprocessing transforms
+    pre_transforms = Compose(
+        [
+            LoadImaged(keys="image", reader=PILReader, converter=lambda x: x.convert("RGB")),
+            EnsureChannelFirstd(keys="image"),
+            CastToTyped(keys="image", dtype=torch.float32),
+            ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+        ]
+    )
+    # Postprocessing transforms
+    post_transforms = Compose(
+        [
+            FlattenSubKeysd(
+                keys="pred",
+                sub_keys=[HoVerNetBranch.NC.value, HoVerNetBranch.NP.value, HoVerNetBranch.HV.value],
+                delete_keys=True,
+            ),
+            HoVerNetInstanceMapPostProcessingd(sobel_kernel_size=3, marker_threshold=0.7, marker_radius=2),
+            HoVerNetNuclearTypePostProcessingd(),
+            FromMetaTensord(keys=["image"]),
+            SaveImaged(
+                keys="instance_map",
+                meta_keys="image_meta_dict",
+                output_ext="png",
+                output_dir=output_dir,
+                output_postfix="instance_map",
+                output_dtype="uint8",
+                separate_folder=False,
+            ),
+            SaveImaged(
+                keys="type_map",
+                meta_keys="image_meta_dict",
+                output_ext="png",
+                output_dir=output_dir,
+                output_postfix="type_map",
+                output_dtype="uint8",
+                separate_folder=False,
+            ),
+        ]
+    )
+    # --------------------------------------------------------------------------
+    # Data and Data Loading
+    # --------------------------------------------------------------------------
+    # List of input image files
+    data_list = [{"image": image} for image in glob(os.path.join(cfg["root"], "*.png"))]
+
+    if multi_gpu:
+        data = partition_dataset(data=data_list, num_partitions=dist.get_world_size())[dist.get_rank()]
+    else:
+        data = data_list
+
+    # Dataset
+    dataset = Dataset(data, transform=pre_transforms)
+
+    # Dataloader
+    data_loader = DataLoader(dataset, num_workers=cfg["ncpu"], batch_size=cfg["batch_size"], pin_memory=True)
+
+    # --------------------------------------------------------------------------
+    # Run some sanity checks
+    # --------------------------------------------------------------------------
+    # Check first sample
+    first_sample = first(data_loader)
+    if first_sample is None:
+        raise ValueError("First sample is None!")
+    print("image: ")
+    print("    shape", first_sample["image"].shape)
+    print("    type: ", type(first_sample["image"]))
+    print("    dtype: ", first_sample["image"].dtype)
+    print(f"batch size: {cfg['batch_size']}")
+    print(f"number of batches: {len(data_loader)}")
+
+    # --------------------------------------------------------------------------
+    # Model
+    # --------------------------------------------------------------------------
+    # Create model and load weights
+    model = HoVerNet(
+        mode=cfg["mode"],
+        in_channels=3,
+        out_classes=cfg["out_classes"],
+    ).to(device)
+    model.load_state_dict(torch.load(cfg["ckpt"], map_location=device)["net"])
+    model.eval()
+    if multi_gpu:
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
+        )
+
+    # --------------------------------------------
+    # Inference
+    # --------------------------------------------
+    # Inference engine
+    sliding_inferer = SlidingWindowHoVerNetInferer(
+        roi_size=cfg["patch_size"],
+        sw_batch_size=cfg["sw_batch_size"],
+        overlap=1.0 - float(cfg["out_size"]) / float(cfg["patch_size"]),
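+        # e.g. `original` mode: overlap = 1 - 80/270 ~ 0.704; `fast` mode: 1 - 164/256 ~ 0.359,
+        # so the valid out_size x out_size centers of neighboring windows tile the image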
padding_mode="constant", + cval=0, + sw_device=device, + progress=True, + extra_input_padding=((cfg["patch_size"] - cfg["out_size"]) // 2,) * 4, + ) + + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=data_loader, + network=model, + postprocessing=post_transforms, + inferer=sliding_inferer, + amp=cfg["use_amp"], + ) + evaluator.run() + + if multi_gpu: + dist.destroy_process_group() + + +def main(): + logging.basicConfig(level=logging.INFO) + + parser = ArgumentParser(description="Tumor detection on whole slide pathology images.") + parser.add_argument( + "--root", + type=str, + default="/workspace/Data/Pathology/CoNSeP/Test/Images", + help="Images root dir", + ) + parser.add_argument("--output", type=str, default="./eval/", dest="output", help="log directory") + parser.add_argument( + "--ckpt", + type=str, + default="./logs/model.pt", + help="Path to the pytorch checkpoint", + ) + parser.add_argument("--mode", type=str, default="original", help="HoVerNet mode (original/fast)") + parser.add_argument("--out-classes", type=int, default=5, help="number of output classes") + parser.add_argument("--bs", type=int, default=1, dest="batch_size", help="batch size") + parser.add_argument("--swbs", type=int, default=8, dest="sw_batch_size", help="sliding window batch size") + parser.add_argument("--no-amp", action="store_false", dest="use_amp", help="deactivate use of amp") + parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu") + parser.add_argument("--ncpu", type=int, default=0, help="number of CPU workers") + args = parser.parse_args() + + config_dict = vars(args) + if config_dict["mode"].lower() == "original": + config_dict["patch_size"] = 270 + config_dict["out_size"] = 80 + elif config_dict["mode"].lower() == "fast": + config_dict["patch_size"] = 256 + config_dict["out_size"] = 164 + else: + raise ValueError("`--mode` should be either `original` or `fast`.") + + print(config_dict) + run(config_dict) + + +if __name__ == "__main__": + main() diff --git a/pathology/hovernet/prepare_patches.py b/pathology/hovernet/prepare_patches.py index bf8e90f9a4..73deca8973 100644 --- a/pathology/hovernet/prepare_patches.py +++ b/pathology/hovernet/prepare_patches.py @@ -1,14 +1,14 @@ -import os -import math -import tqdm import glob -import shutil +import math +import os import pathlib +import shutil +from argparse import ArgumentParser import numpy as np import scipy.io as sio +import tqdm from PIL import Image -from argparse import ArgumentParser def load_img(path): @@ -30,7 +30,7 @@ def load_ann(path): return ann -class PatchExtractor(): +class PatchExtractor: """Extractor to generate patches with or without padding. Turn on debug mode to see how it is done. 
@@ -42,9 +42,9 @@ class PatchExtractor():
         a list of sub patches, each patch has dtype same as x

     Examples:
-        >>> xtractor = PatchExtractor((450, 450), (120, 120))
+        >>> extractor = PatchExtractor((450, 450), (120, 120))
         >>> img = np.full([1200, 1200, 3], 255, np.uint8)
-        >>> patches = xtractor.extract(img, 'mirror')
+        >>> patches = extractor.extract(img, 'mirror')

     """

@@ -166,9 +166,7 @@ def main(cfg):
         os.makedirs(out_dir)

     pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
-    pbarx = tqdm.tqdm(
-        total=len(file_list), bar_format=pbar_format, ascii=True, position=0
-    )
+    pbarx = tqdm.tqdm(total=len(file_list), bar_format=pbar_format, ascii=True, position=0)

     for file_path in file_list:
         base_name = pathlib.Path(file_path).stem
@@ -210,12 +208,12 @@ def parse_arguments():
     parser.add_argument(
         "--root",
         type=str,
-        default="/home/yunliu/Workspace/Data/CoNSeP",
+        default="/workspace/Data/Pathology/CoNSeP",
        help="root path to image folder containing training/test",
     )
     parser.add_argument("--type", type=str, default="mirror", dest="extract_type", help="Choose 'mirror' or 'valid'")
-    parser.add_argument("--ps", nargs='+', type=int, default=[540, 540], dest="patch_size", help="patch size")
-    parser.add_argument("--ss", nargs='+', type=int, default=[164, 164], dest="step_size", help="patch size")
+    parser.add_argument("--ps", nargs="+", type=int, default=[540, 540], dest="patch_size", help="patch size")
+    parser.add_argument("--ss", nargs="+", type=int, default=[164, 164], dest="step_size", help="step size")

     args = parser.parse_args()
     config_dict = vars(args)
diff --git a/pathology/hovernet/training.py b/pathology/hovernet/training.py
index d402dc5e0b..6c79e9f74c 100644
--- a/pathology/hovernet/training.py
+++ b/pathology/hovernet/training.py
@@ -45,32 +45,28 @@


 def create_log_dir(cfg):
-    timestamp = time.strftime("%y%m%d-%H%M")
-    run_folder_name = (
-        f"{timestamp}_hovernet_bs{cfg['batch_size']}_ep{cfg['n_epochs']}_lr{cfg['lr']}_seed{cfg['seed']}_stage{cfg['stage']}"
-    )
-    log_dir = os.path.join(cfg["logdir"], run_folder_name)
-    print(f"Logs and model are saved at '{log_dir}'.")
+    log_dir = cfg["log_dir"]
+    if cfg["stage"] == 0:
+        log_dir = os.path.join(log_dir, "stage0")
+    print(f"Logs and models are saved at '{log_dir}'.")
     if not os.path.exists(log_dir):
         os.makedirs(log_dir, exist_ok=True)
     return log_dir


 def prepare_data(data_dir, phase):
-    # prepare datalist
-    images = sorted(
-        glob.glob(os.path.join(data_dir, f"{phase}/*image.npy")))
-    inst_maps = sorted(
-        glob.glob(os.path.join(data_dir, f"{phase}/*inst_map.npy")))
-    type_maps = sorted(
-        glob.glob(os.path.join(data_dir, f"{phase}/*type_map.npy")))
-
-    data_dicts = [
+    """prepare data list"""
+
+    data_dir = os.path.join(data_dir, phase)
+    images = sorted(glob.glob(os.path.join(data_dir, "*image.npy")))
+    inst_maps = sorted(glob.glob(os.path.join(data_dir, "*inst_map.npy")))
+    type_maps = sorted(glob.glob(os.path.join(data_dir, "*type_map.npy")))
+
+    data_list = [
         {"image": _image, "label_inst": _inst_map, "label_type": _type_map}
         for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
     ]
-
-    return data_dicts
+    return data_list


 def get_loaders(cfg, train_transforms, val_transforms):
@@ -104,13 +100,10 @@ def get_loaders(cfg, train_transforms, val_transforms):
         batch_size=cfg["batch_size"],
         num_workers=cfg["num_workers"],
         shuffle=True,
-        pin_memory=torch.cuda.is_available()
+        pin_memory=torch.cuda.is_available(),
     )
     val_loader = DataLoader(
-        valid_ds,
-        batch_size=cfg["batch_size"],
num_workers=cfg["num_workers"], - pin_memory=torch.cuda.is_available() + valid_ds, batch_size=cfg["batch_size"], num_workers=cfg["num_workers"], pin_memory=torch.cuda.is_available() ) return train_loader, val_loader @@ -144,7 +137,7 @@ def create_model(cfg, device): pretrained_url=None, freeze_encoder=False, ).to(device) - model.load_state_dict(torch.load(cfg["ckpt_path"])['net']) + model.load_state_dict(torch.load(cfg["ckpt"])["net"]) print(f'stage{cfg["stage"]}, success load weight!') return model @@ -172,7 +165,6 @@ def run(log_dir, cfg): # Data Loading and Preprocessing # -------------------------------------------------------------------------- # __________________________________________________________________________ - # __________________________________________________________________________ # Build MONAI preprocessing train_transforms = Compose( [ @@ -194,11 +186,13 @@ def run(log_dir, cfg): ), RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=0), RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=1), - OneOf(transforms=[ - RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0), - MedianSmoothd(keys=["image"], radius=1), - RandGaussianNoised(keys=["image"], prob=1.0, std=0.05) - ]), + OneOf( + transforms=[ + RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0), + MedianSmoothd(keys=["image"], radius=1), + RandGaussianNoised(keys=["image"], prob=1.0, std=0.05), + ] + ), CastToTyped(keys="image", dtype=np.uint8), TorchVisiond( keys=["image"], @@ -206,7 +200,7 @@ def run(log_dir, cfg): brightness=(229 / 255.0, 281 / 255.0), contrast=(0.95, 1.10), saturation=(0.8, 1.2), - hue=(-0.04, 0.04) + hue=(-0.04, 0.04), ), AsDiscreted(keys=["label_type"], to_onehot=[5]), ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), @@ -258,10 +252,12 @@ def run(log_dir, cfg): loss_function = HoVerNetLoss(lambda_hv_mse=1.0) optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"], weight_decay=1e-5) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25) - post_process_np = Compose([ - Activationsd(keys=HoVerNetBranch.NP.value, softmax=True), - AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True), - ]) + post_process_np = Compose( + [ + Activationsd(keys=HoVerNetBranch.NP.value, softmax=True), + AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True), + ] + ) post_process = Lambdad(keys="pred", func=post_process_np) # -------------------------------------------- @@ -282,10 +278,15 @@ def run(log_dir, cfg): evaluator = SupervisedEvaluator( device=device, val_data_loader=val_loader, - prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']), + prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]), network=model, postprocessing=post_process, - key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))}, + key_val_metric={ + "val_dice": MeanDice( + include_background=False, + output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value), + ) + }, val_handlers=val_handlers, amp=cfg["amp"], ) @@ -298,6 +299,8 @@ def run(log_dir, cfg): save_dir=log_dir, save_dict={"net": model, "opt": optimizer}, save_interval=cfg["save_interval"], + save_final=True, + final_filename="model.pt", epoch_level=True, ), 
StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)), @@ -311,12 +314,17 @@ def run(log_dir, cfg): device=device, max_epochs=cfg["n_epochs"], train_data_loader=train_loader, - prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']), + prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]), network=model, optimizer=optimizer, loss_function=loss_function, postprocessing=post_process, - key_train_metric={"train_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))}, + key_train_metric={ + "train_dice": MeanDice( + include_background=False, + output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value), + ) + }, train_handlers=train_handlers, amp=cfg["amp"], ) @@ -331,18 +339,18 @@ def main(): parser.add_argument( "--root", type=str, - default="/workspace/Data/CoNSeP/Prepared", + default="/workspace/Data/Pathology/CoNSeP/Prepared", help="root data dir", ) - parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory") + parser.add_argument("--log-dir", type=str, default="./logs/", help="log directory") parser.add_argument("-s", "--seed", type=int, default=24) parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size") - parser.add_argument("--ep", type=int, default=3, dest="n_epochs", help="number of epochs") + parser.add_argument("--ep", type=int, default=50, dest="n_epochs", help="number of epochs") parser.add_argument("--lr", type=float, default=1e-4, dest="lr", help="initial learning rate") parser.add_argument("--step", type=int, default=25, dest="step_size", help="period of learning rate decay") parser.add_argument("-f", "--val_freq", type=int, default=1, help="validation frequence") - parser.add_argument("--stage", type=int, default=0, dest="stage", help="training stage") + parser.add_argument("--stage", type=int, default=0, help="training stage") parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp") parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes") parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`") @@ -350,10 +358,12 @@ def main(): parser.add_argument("--save_interval", type=int, default=10) parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers") parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu") - parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path") + parser.add_argument("--ckpt", type=str, dest="ckpt", help="model checkpoint path") args = parser.parse_args() cfg = vars(args) + if cfg["stage"] == 1 and not cfg["ckpt"] and cfg["log_dir"]: + cfg["ckpt"] = os.path.join(cfg["log_dir"], "stage0", "model.pt") print(cfg) logging.basicConfig(level=logging.INFO)