diff --git a/pathology/hovernet/README.MD b/pathology/hovernet/README.MD
index fef6268eb2..b5bdc4ac58 100644
--- a/pathology/hovernet/README.MD
+++ b/pathology/hovernet/README.MD
@@ -3,17 +3,18 @@
This folder contains Ignite-based examples to train and validate a HoVerNet model.
It also has torch-based notebooks to run training and evaluation.
-
implementation based on:
-Simon Graham et al., HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images.' Medical Image Analysis, (2019). https://arxiv.org/abs/1812.06499
+Simon Graham et al., HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images.' Medical Image Analysis, (2019).
### 1. Data
-CoNSeP datasets which are used in the examples can be downloaded from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/.
-- First download CoNSeP dataset to `data_root`.
-- Run prepare_patches.py to prepare patches from images.
+The CoNSeP dataset used in the examples can be downloaded from <https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/>.
+
+- First, download the CoNSeP dataset to `DATA_ROOT` (default is `"/workspace/Data/Pathology/CoNSeP"`).
+- Then run `python prepare_patches.py` to prepare patches from the images.
### 2. Questions and bugs
@@ -21,71 +22,96 @@ CoNSeP datasets which are used in the examples can be downloaded from https://wa
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
-
### 3. List of notebooks and examples
+
#### [Prepare Your Data](./prepare_patches.py)
-This example is used to prepare patches from tiles referring to the implementation from https://github.com/vqdang/hover_net/blob/master/extract_patches.py. Prepared patches will be saved in `data_root`/Prepared.
+
+This example prepares patches from tiles, following the implementation in <https://github.com/vqdang/hover_net/blob/master/extract_patches.py>. Prepared patches will be saved in `DATA_ROOT`/Prepared.
```bash
-# Run to know all possible options
+# Run to get all possible arguments
python ./prepare_patches.py -h
-# Prepare patches from images
+# Prepare patches from images using default arguments
+python ./prepare_patches.py
+
+# Prepare patches using custom arguments
python ./prepare_patches.py \
- --root `data_root`
+ --root `DATA_ROOT` \
+ --ps 540 540 \
+ --ss 164 164
```
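+
+As a rough sketch of what the extractor does (a minimal sketch; the 1000x1000 tile size is the CoNSeP image size, and `PatchExtractor` is imported from `prepare_patches.py` in this folder), a 540x540 window slides over each tile in 164-pixel steps, mirror-padding at the borders:
+
+```python
+import numpy as np
+
+from prepare_patches import PatchExtractor
+
+extractor = PatchExtractor((540, 540), (164, 164))
+img = np.full([1000, 1000, 3], 255, np.uint8)  # dummy tile the size of a CoNSeP image
+patches = extractor.extract(img, "mirror")  # list of patches, same dtype as img
+print(len(patches), patches[0].shape)
+```
+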
#### [HoVerNet Training](./training.py)
+
This example uses MONAI workflow to train a HoVerNet model on prepared CoNSeP dataset.
-Since HoVerNet is training via a two-stage approach. First initialised the model with pre-trained weights on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848), trained only the decoders for the first 50 epochs, and then fine-tuned all layers for another 50 epochs. We need to specify `--stage` during training.
+HoVerNet is trained via a two-stage approach: the model is first initialized with weights pre-trained on the [ImageNet dataset](https://ieeexplore.ieee.org/document/5206848) and only the decoders are trained for the first 50 epochs; then all layers are fine-tuned for another 50 epochs. We need to specify `--stage` during training.
Each user is responsible for checking the content of models/datasets and the applicable licenses and determining if suitable for the intended use.
The license for the pre-trained model used in these examples is different from the MONAI license. Please check the source where these weights are obtained from:
-https://github.com/vqdang/hover_net#data-format
+<https://github.com/vqdang/hover_net#data-format>
+
+If you didn't use the default path in data preparation, set ``--root `DATA_ROOT`/Prepared`` for each of the training commands.
```bash
-# Run to know all possible options
+# Run to get all possible arguments
python ./training.py -h
-# Train a hovernet model on single-gpu(replace with your own ckpt path)
+# Train a HoVerNet model on a single GPU or CPU only (for stage 1, replace the ckpt path with your own)
export CUDA_VISIBLE_DEVICES=0; python training.py \
- --ep 50 \
--stage 0 \
+ --ep 50 \
--bs 16 \
- --root `save_root`
+ --log-dir ./logs
export CUDA_VISIBLE_DEVICES=0; python training.py \
- --ep 50 \
--stage 1 \
- --bs 4 \
- --root `save_root` \
- --ckpt logs/stage0/checkpoint_epoch=50.pt
-
-# Train a hovernet model on multi-gpu (NVIDIA)(replace with your own ckpt path)
-torchrun --nnodes=1 --nproc_per_node=2 training.py \
- --ep 50 \
- --bs 8 \
- --root `save_root` \
- --stage 0
-torchrun --nnodes=1 --nproc_per_node=2 training.py \
- --ep 50 \
- --bs 2 \
- --root `save_root` \
- --stage 1 \
- --ckpt logs/stage0/checkpoint_epoch=50.pt
+ --ep 50 \
+ --bs 16 \
+ --log-dir ./logs \
+ --ckpt logs/stage0/model.pt
+
+# Train a HoVerNet model on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 training.py
+torchrun --nnodes=1 --nproc_per_node=2 training.py --stage 1
```
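+
+The two training stages map onto the `freeze_encoder` flag of MONAI's `HoVerNet` (a minimal sketch; the exact constructor arguments here are assumptions, but `training.py` passes `freeze_encoder=False` when fine-tuning):
+
+```python
+from monai.networks.nets import HoVerNet
+
+# Stage 0: the ImageNet-pretrained encoder stays frozen; only the decoder
+# branches receive gradient updates.
+stage0_model = HoVerNet(mode="original", in_channels=3, out_classes=5, freeze_encoder=True)
+
+# Stage 1: every layer is fine-tuned, starting from the stage-0 checkpoint.
+stage1_model = HoVerNet(mode="original", in_channels=3, out_classes=5, freeze_encoder=False)
+```
+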
#### [HoVerNet Validation](./evaluation.py)
+
This example uses MONAI workflow to evaluate the trained HoVerNet model on prepared test data from CoNSeP dataset.
Using their metrics in `original` mode, we reproduce the results with Dice: 0.82762; PQ: 0.48976; F1d: 0.73592.
+
```bash
-# Run to know all possible options
+# Run to get all possible arguments
python ./evaluation.py -h
-# Evaluate a HoVerNet model
-python ./evaluation.py
+# Evaluate a HoVerNet model on a single GPU or CPU only
+python ./evaluation.py \
    --root `DATA_ROOT`/Prepared \
- --ckpt logs/stage0/checkpoint_epoch=50.pt
+ --ckpt logs/stage0/model.pt
+
+# Evaluate a HoVerNet model on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 evaluation.py
+```
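+
+Only the nuclear-prediction (NP) branch feeds the Dice score; a minimal sketch of the metric wiring used in `evaluation.py` (the `from_engine_hovernet` import path is an assumption, adjust to your MONAI version):
+
+```python
+from monai.apps.pathology.handlers.utils import from_engine_hovernet  # assumed import path
+from monai.handlers import MeanDice
+from monai.utils import HoVerNetBranch
+
+# pull the NP branch out of the nested prediction/label dicts before scoring
+val_dice = MeanDice(
+    include_background=False,
+    output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
+)
+```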
+
+#### [HoVerNet Inference](./inference.py)
+
+This example uses MONAI workflow to run inference with a HoVerNet model on arbitrarily sized regions of interest.
+Under the hood, it uses a sliding-window approach to run inference on overlapping patches and then stitches the
+results together into an output image of the same size as the input. It then runs post-processing on this output
+image to create the final results. This example saves the instance map and type map as PNG files, but it can be
+modified to save any output of interest.
+
+```bash
+# Run to get all possible arguments
+python ./inference.py -h
+
+# Run HoVerNet inference on a single GPU or CPU only
+python ./inference.py \
+    --root `DATA_ROOT`/Test/Images \
+ --ckpt logs/stage0/model.pt
+
+# Run HoVerNet inference on multi-GPU with default arguments
+torchrun --nnodes=1 --nproc_per_node=2 ./inference.py
```
## Disclaimer
diff --git a/pathology/hovernet/evaluation.py b/pathology/hovernet/evaluation.py
index 476da843ca..723c4e45d8 100644
--- a/pathology/hovernet/evaluation.py
+++ b/pathology/hovernet/evaluation.py
@@ -28,21 +28,18 @@
def prepare_data(data_dir, phase):
- data_dir = os.path.join(data_dir, phase)
+ """prepare data list"""
- images = list(sorted(
- glob.glob(os.path.join(data_dir, "*/*image.npy"))))
- inst_maps = list(sorted(
- glob.glob(os.path.join(data_dir, "*/*inst_map.npy"))))
- type_maps = list(sorted(
- glob.glob(os.path.join(data_dir, "*/*type_map.npy"))))
+ data_dir = os.path.join(data_dir, phase)
+ images = sorted(glob.glob(os.path.join(data_dir, "*image.npy")))
+ inst_maps = sorted(glob.glob(os.path.join(data_dir, "*inst_map.npy")))
+ type_maps = sorted(glob.glob(os.path.join(data_dir, "*type_map.npy")))
- data_dicts = [
+ data_list = [
{"image": _image, "label_inst": _inst_map, "label_type": _type_map}
for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
]
-
- return data_dicts
+ return data_list
def run(cfg):
@@ -75,13 +72,10 @@ def run(cfg):
)
# Create MONAI DataLoaders
- valid_data = prepare_data(cfg["root"], "valid")
+ valid_data = prepare_data(cfg["root"], "Test")
valid_ds = CacheDataset(data=valid_data, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(
- valid_ds,
- batch_size=cfg["batch_size"],
- num_workers=cfg["num_workers"],
- pin_memory=torch.cuda.is_available()
+ valid_ds, batch_size=cfg["batch_size"], num_workers=cfg["num_workers"], pin_memory=torch.cuda.is_available()
)
# initialize model
@@ -95,23 +89,31 @@ def run(cfg):
freeze_encoder=False,
).to(device)
- post_process_np = Compose([
- Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
- Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1: 2, ...] > 0.5)])
+ post_process_np = Compose(
+ [
+ Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
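+            # keep the foreground probability channel and binarize at 0.5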
+ Lambdad(keys=HoVerNetBranch.NP.value, func=lambda x: x[1:2, ...] > 0.5),
+ ]
+ )
post_process = Lambdad(keys="pred", func=post_process_np)
# Evaluator
val_handlers = [
- CheckpointLoader(load_path=cfg["ckpt_path"], load_dict={"net": model}),
+ CheckpointLoader(load_path=cfg["ckpt"], load_dict={"net": model}),
StatsHandler(output_transform=lambda x: None),
]
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=val_loader,
- prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
network=model,
postprocessing=post_process,
- key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ key_val_metric={
+ "val_dice": MeanDice(
+ include_background=False,
+ output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
+ )
+ },
val_handlers=val_handlers,
amp=cfg["amp"],
)
@@ -125,10 +127,15 @@ def main():
parser.add_argument(
"--root",
type=str,
- default="/workspace/Data/CoNSeP/Prepared/consep",
+ default="/workspace/Data/Pathology/CoNSeP/Prepared",
help="root data dir",
)
-
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="./logs/model.pt",
+ help="Path to the pytorch checkpoint",
+ )
parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
@@ -136,7 +143,6 @@ def main():
parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
parser.add_argument("--use_gpu", type=bool, default=True, dest="use_gpu", help="whether to use gpu")
- parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
args = parser.parse_args()
cfg = vars(args)
diff --git a/pathology/hovernet/inference.py b/pathology/hovernet/inference.py
new file mode 100644
index 0000000000..5af7ef9d82
--- /dev/null
+++ b/pathology/hovernet/inference.py
@@ -0,0 +1,213 @@
+import logging
+import os
+import time
+from argparse import ArgumentParser
+from glob import glob
+
+import torch
+import torch.distributed as dist
+
+from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
+from monai.apps.pathology.transforms import (
+ HoVerNetInstanceMapPostProcessingd,
+ HoVerNetNuclearTypePostProcessingd,
+)
+from monai.data import DataLoader, Dataset, PILReader, partition_dataset
+from monai.engines import SupervisedEvaluator
+from monai.networks.nets import HoVerNet
+from monai.transforms import (
+ CastToTyped,
+ Compose,
+ EnsureChannelFirstd,
+ FromMetaTensord,
+ LoadImaged,
+ FlattenSubKeysd,
+ SaveImaged,
+ ScaleIntensityRanged,
+)
+from monai.utils import HoVerNetBranch, first
+
+
+def create_output_dir(cfg):
+ output_dir = cfg["output"]
+ print(f"Outputs are saved at '{output_dir}'.")
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ return output_dir
+
+
+def run(cfg):
+ # --------------------------------------------------------------------------
+ # Set Directory, Device,
+ # --------------------------------------------------------------------------
+ output_dir = create_output_dir(cfg)
+ multi_gpu = cfg["use_gpu"] if torch.cuda.device_count() > 1 else False
+ if multi_gpu:
+ dist.init_process_group(backend="nccl", init_method="env://")
+ device = torch.device("cuda:{}".format(dist.get_rank()))
+ torch.cuda.set_device(device)
+ if dist.get_rank() == 0:
+ print(f"Running multi-gpu with {dist.get_world_size()} GPUs")
+ else:
+ device = torch.device("cuda" if cfg["use_gpu"] and torch.cuda.is_available() else "cpu")
+ # --------------------------------------------------------------------------
+ # Transforms
+ # --------------------------------------------------------------------------
+ # Preprocessing transforms
+ pre_transforms = Compose(
+ [
+ LoadImaged(keys="image", reader=PILReader, converter=lambda x: x.convert("RGB")),
+ EnsureChannelFirstd(keys="image"),
+ CastToTyped(keys="image", dtype=torch.float32),
+ ScaleIntensityRanged(keys="image", a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
+ ]
+ )
+ # Postprocessing transforms
+ post_transforms = Compose(
+ [
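+            # move the HoVerNet branch outputs (NC, NP, HV) out of the nested "pred" dict into top-level keys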
+ FlattenSubKeysd(
+ keys="pred",
+ sub_keys=[HoVerNetBranch.NC.value, HoVerNetBranch.NP.value, HoVerNetBranch.HV.value],
+ delete_keys=True,
+ ),
+ HoVerNetInstanceMapPostProcessingd(sobel_kernel_size=3, marker_threshold=0.7, marker_radius=2),
+ HoVerNetNuclearTypePostProcessingd(),
+ FromMetaTensord(keys=["image"]),
+ SaveImaged(
+ keys="instance_map",
+ meta_keys="image_meta_dict",
+ output_ext="png",
+ output_dir=output_dir,
+ output_postfix="instance_map",
+ output_dtype="uint8",
+ separate_folder=False,
+ ),
+ SaveImaged(
+ keys="type_map",
+ meta_keys="image_meta_dict",
+ output_ext="png",
+ output_dir=output_dir,
+ output_postfix="type_map",
+ output_dtype="uint8",
+ separate_folder=False,
+ ),
+ ]
+ )
+ # --------------------------------------------------------------------------
+ # Data and Data Loading
+ # --------------------------------------------------------------------------
+    # List of input region-of-interest images
+ data_list = [{"image": image} for image in glob(os.path.join(cfg["root"], "*.png"))]
+
+ if multi_gpu:
+ data = partition_dataset(data=data_list, num_partitions=dist.get_world_size())[dist.get_rank()]
+ else:
+ data = data_list
+
+ # Dataset
+ dataset = Dataset(data, transform=pre_transforms)
+
+ # Dataloader
+ data_loader = DataLoader(dataset, num_workers=cfg["ncpu"], batch_size=cfg["batch_size"], pin_memory=True)
+
+ # --------------------------------------------------------------------------
+ # Run some sanity checks
+ # --------------------------------------------------------------------------
+ # Check first sample
+ first_sample = first(data_loader)
+ if first_sample is None:
+ raise ValueError("First sample is None!")
+ print("image: ")
+ print(" shape", first_sample["image"].shape)
+ print(" type: ", type(first_sample["image"]))
+ print(" dtype: ", first_sample["image"].dtype)
+ print(f"batch size: {cfg['batch_size']}")
+ print(f"number of batches: {len(data_loader)}")
+
+ # --------------------------------------------------------------------------
+ # Model
+ # --------------------------------------------------------------------------
+ # Create model and load weights
+ model = HoVerNet(
+ mode=cfg["mode"],
+ in_channels=3,
+ out_classes=cfg["out_classes"],
+ ).to(device)
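+    # checkpoints were saved with save_dict={"net": model}, hence the "net" key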
+ model.load_state_dict(torch.load(cfg["ckpt"], map_location=device)["net"])
+ model.eval()
+ if multi_gpu:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
+ )
+
+ # --------------------------------------------
+ # Inference
+ # --------------------------------------------
+ # Inference engine
+ sliding_inferer = SlidingWindowHoVerNetInferer(
+ roi_size=cfg["patch_size"],
+ sw_batch_size=cfg["sw_batch_size"],
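+        # overlap windows by the margin the network crops off each patch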
+ overlap=1.0 - float(cfg["out_size"]) / float(cfg["patch_size"]),
+ padding_mode="constant",
+ cval=0,
+ sw_device=device,
+ progress=True,
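+        # pad all four image borders so edge pixels receive valid predictions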
+ extra_input_padding=((cfg["patch_size"] - cfg["out_size"]) // 2,) * 4,
+ )
+
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=data_loader,
+ network=model,
+ postprocessing=post_transforms,
+ inferer=sliding_inferer,
+ amp=cfg["use_amp"],
+ )
+ evaluator.run()
+
+ if multi_gpu:
+ dist.destroy_process_group()
+
+
+def main():
+ logging.basicConfig(level=logging.INFO)
+
+    parser = ArgumentParser(description="HoVerNet inference on region-of-interest pathology images.")
+ parser.add_argument(
+ "--root",
+ type=str,
+ default="/workspace/Data/Pathology/CoNSeP/Test/Images",
+ help="Images root dir",
+ )
+ parser.add_argument("--output", type=str, default="./eval/", dest="output", help="log directory")
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="./logs/model.pt",
+ help="Path to the pytorch checkpoint",
+ )
+ parser.add_argument("--mode", type=str, default="original", help="HoVerNet mode (original/fast)")
+ parser.add_argument("--out-classes", type=int, default=5, help="number of output classes")
+ parser.add_argument("--bs", type=int, default=1, dest="batch_size", help="batch size")
+ parser.add_argument("--swbs", type=int, default=8, dest="sw_batch_size", help="sliding window batch size")
+ parser.add_argument("--no-amp", action="store_false", dest="use_amp", help="deactivate use of amp")
+ parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
+ parser.add_argument("--ncpu", type=int, default=0, help="number of CPU workers")
+ args = parser.parse_args()
+
+ config_dict = vars(args)
+ if config_dict["mode"].lower() == "original":
+ config_dict["patch_size"] = 270
+ config_dict["out_size"] = 80
+ elif config_dict["mode"].lower() == "fast":
+ config_dict["patch_size"] = 256
+ config_dict["out_size"] = 164
+ else:
+ raise ValueError("`--mode` should be either `original` or `fast`.")
+
+ print(config_dict)
+ run(config_dict)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pathology/hovernet/prepare_patches.py b/pathology/hovernet/prepare_patches.py
index bf8e90f9a4..73deca8973 100644
--- a/pathology/hovernet/prepare_patches.py
+++ b/pathology/hovernet/prepare_patches.py
@@ -1,14 +1,14 @@
-import os
-import math
-import tqdm
import glob
-import shutil
+import math
+import os
import pathlib
+import shutil
+from argparse import ArgumentParser
import numpy as np
import scipy.io as sio
+import tqdm
from PIL import Image
-from argparse import ArgumentParser
def load_img(path):
@@ -30,7 +30,7 @@ def load_ann(path):
return ann
-class PatchExtractor():
+class PatchExtractor:
"""Extractor to generate patches with or without padding.
Turn on debug mode to see how it is done.
@@ -42,9 +42,9 @@ class PatchExtractor():
a list of sub patches, each patch has dtype same as x
Examples:
- >>> xtractor = PatchExtractor((450, 450), (120, 120))
+ >>> extractor = PatchExtractor((450, 450), (120, 120))
>>> img = np.full([1200, 1200, 3], 255, np.uint8)
- >>> patches = xtractor.extract(img, 'mirror')
+ >>> patches = extractor.extract(img, 'mirror')
"""
@@ -166,9 +166,7 @@ def main(cfg):
os.makedirs(out_dir)
pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
- pbarx = tqdm.tqdm(
- total=len(file_list), bar_format=pbar_format, ascii=True, position=0
- )
+ pbarx = tqdm.tqdm(total=len(file_list), bar_format=pbar_format, ascii=True, position=0)
for file_path in file_list:
base_name = pathlib.Path(file_path).stem
@@ -210,12 +208,12 @@ def parse_arguments():
parser.add_argument(
"--root",
type=str,
- default="/home/yunliu/Workspace/Data/CoNSeP",
+ default="/workspace/Data/Pathology/CoNSeP",
help="root path to image folder containing training/test",
)
parser.add_argument("--type", type=str, default="mirror", dest="extract_type", help="Choose 'mirror' or 'valid'")
- parser.add_argument("--ps", nargs='+', type=int, default=[540, 540], dest="patch_size", help="patch size")
- parser.add_argument("--ss", nargs='+', type=int, default=[164, 164], dest="step_size", help="patch size")
+ parser.add_argument("--ps", nargs="+", type=int, default=[540, 540], dest="patch_size", help="patch size")
+ parser.add_argument("--ss", nargs="+", type=int, default=[164, 164], dest="step_size", help="patch size")
args = parser.parse_args()
config_dict = vars(args)
diff --git a/pathology/hovernet/training.py b/pathology/hovernet/training.py
index d402dc5e0b..6c79e9f74c 100644
--- a/pathology/hovernet/training.py
+++ b/pathology/hovernet/training.py
@@ -45,32 +45,28 @@
def create_log_dir(cfg):
- timestamp = time.strftime("%y%m%d-%H%M")
- run_folder_name = (
- f"{timestamp}_hovernet_bs{cfg['batch_size']}_ep{cfg['n_epochs']}_lr{cfg['lr']}_seed{cfg['seed']}_stage{cfg['stage']}"
- )
- log_dir = os.path.join(cfg["logdir"], run_folder_name)
- print(f"Logs and model are saved at '{log_dir}'.")
+ log_dir = cfg["log_dir"]
+ if cfg["stage"] == 0:
+ log_dir = os.path.join(log_dir, "stage0")
+ print(f"Logs and models are saved at '{log_dir}'.")
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
return log_dir
def prepare_data(data_dir, phase):
- # prepare datalist
- images = sorted(
- glob.glob(os.path.join(data_dir, f"{phase}/*image.npy")))
- inst_maps = sorted(
- glob.glob(os.path.join(data_dir, f"{phase}/*inst_map.npy")))
- type_maps = sorted(
- glob.glob(os.path.join(data_dir, f"{phase}/*type_map.npy")))
-
- data_dicts = [
+ """prepare data list"""
+
+ data_dir = os.path.join(data_dir, phase)
+ images = sorted(glob.glob(os.path.join(data_dir, "*image.npy")))
+ inst_maps = sorted(glob.glob(os.path.join(data_dir, "*inst_map.npy")))
+ type_maps = sorted(glob.glob(os.path.join(data_dir, "*type_map.npy")))
+
+ data_list = [
{"image": _image, "label_inst": _inst_map, "label_type": _type_map}
for _image, _inst_map, _type_map in zip(images, inst_maps, type_maps)
]
-
- return data_dicts
+ return data_list
def get_loaders(cfg, train_transforms, val_transforms):
@@ -104,13 +100,10 @@ def get_loaders(cfg, train_transforms, val_transforms):
batch_size=cfg["batch_size"],
num_workers=cfg["num_workers"],
shuffle=True,
- pin_memory=torch.cuda.is_available()
+ pin_memory=torch.cuda.is_available(),
)
val_loader = DataLoader(
- valid_ds,
- batch_size=cfg["batch_size"],
- num_workers=cfg["num_workers"],
- pin_memory=torch.cuda.is_available()
+ valid_ds, batch_size=cfg["batch_size"], num_workers=cfg["num_workers"], pin_memory=torch.cuda.is_available()
)
return train_loader, val_loader
@@ -144,7 +137,7 @@ def create_model(cfg, device):
pretrained_url=None,
freeze_encoder=False,
).to(device)
- model.load_state_dict(torch.load(cfg["ckpt_path"])['net'])
+ model.load_state_dict(torch.load(cfg["ckpt"])["net"])
    print(f'Stage {cfg["stage"]}: loaded checkpoint weights successfully.')
return model
@@ -172,7 +165,6 @@ def run(log_dir, cfg):
# Data Loading and Preprocessing
# --------------------------------------------------------------------------
# __________________________________________________________________________
- # __________________________________________________________________________
# Build MONAI preprocessing
train_transforms = Compose(
[
@@ -194,11 +186,13 @@ def run(log_dir, cfg):
),
RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=0),
RandFlipd(keys=["image", "label_inst", "label_type"], prob=0.5, spatial_axis=1),
- OneOf(transforms=[
- RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0),
- MedianSmoothd(keys=["image"], radius=1),
- RandGaussianNoised(keys=["image"], prob=1.0, std=0.05)
- ]),
+ OneOf(
+ transforms=[
+ RandGaussianSmoothd(keys=["image"], sigma_x=(0.1, 1.1), sigma_y=(0.1, 1.1), prob=1.0),
+ MedianSmoothd(keys=["image"], radius=1),
+ RandGaussianNoised(keys=["image"], prob=1.0, std=0.05),
+ ]
+ ),
CastToTyped(keys="image", dtype=np.uint8),
TorchVisiond(
keys=["image"],
@@ -206,7 +200,7 @@ def run(log_dir, cfg):
brightness=(229 / 255.0, 281 / 255.0),
contrast=(0.95, 1.10),
saturation=(0.8, 1.2),
- hue=(-0.04, 0.04)
+ hue=(-0.04, 0.04),
),
AsDiscreted(keys=["label_type"], to_onehot=[5]),
ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
@@ -258,10 +252,12 @@ def run(log_dir, cfg):
loss_function = HoVerNetLoss(lambda_hv_mse=1.0)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["lr"], weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25)
- post_process_np = Compose([
- Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
- AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True),
- ])
+ post_process_np = Compose(
+ [
+ Activationsd(keys=HoVerNetBranch.NP.value, softmax=True),
+ AsDiscreted(keys=HoVerNetBranch.NP.value, argmax=True),
+ ]
+ )
post_process = Lambdad(keys="pred", func=post_process_np)
# --------------------------------------------
@@ -282,10 +278,15 @@ def run(log_dir, cfg):
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=val_loader,
- prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
network=model,
postprocessing=post_process,
- key_val_metric={"val_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ key_val_metric={
+ "val_dice": MeanDice(
+ include_background=False,
+ output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
+ )
+ },
val_handlers=val_handlers,
amp=cfg["amp"],
)
@@ -298,6 +299,8 @@ def run(log_dir, cfg):
save_dir=log_dir,
save_dict={"net": model, "opt": optimizer},
save_interval=cfg["save_interval"],
+ save_final=True,
+ final_filename="model.pt",
epoch_level=True,
),
StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
@@ -311,12 +314,17 @@ def run(log_dir, cfg):
device=device,
max_epochs=cfg["n_epochs"],
train_data_loader=train_loader,
- prepare_batch=PrepareBatchHoVerNet(extra_keys=['label_type', 'hover_label_inst']),
+ prepare_batch=PrepareBatchHoVerNet(extra_keys=["label_type", "hover_label_inst"]),
network=model,
optimizer=optimizer,
loss_function=loss_function,
postprocessing=post_process,
- key_train_metric={"train_dice": MeanDice(include_background=False, output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value))},
+ key_train_metric={
+ "train_dice": MeanDice(
+ include_background=False,
+ output_transform=from_engine_hovernet(keys=["pred", "label"], nested_key=HoVerNetBranch.NP.value),
+ )
+ },
train_handlers=train_handlers,
amp=cfg["amp"],
)
@@ -331,18 +339,18 @@ def main():
parser.add_argument(
"--root",
type=str,
- default="/workspace/Data/CoNSeP/Prepared",
+ default="/workspace/Data/Pathology/CoNSeP/Prepared",
help="root data dir",
)
- parser.add_argument("--logdir", type=str, default="./logs/", dest="logdir", help="log directory")
+ parser.add_argument("--log-dir", type=str, default="./logs/", help="log directory")
parser.add_argument("-s", "--seed", type=int, default=24)
parser.add_argument("--bs", type=int, default=16, dest="batch_size", help="batch size")
- parser.add_argument("--ep", type=int, default=3, dest="n_epochs", help="number of epochs")
+ parser.add_argument("--ep", type=int, default=50, dest="n_epochs", help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-4, dest="lr", help="initial learning rate")
parser.add_argument("--step", type=int, default=25, dest="step_size", help="period of learning rate decay")
parser.add_argument("-f", "--val_freq", type=int, default=1, help="validation frequence")
- parser.add_argument("--stage", type=int, default=0, dest="stage", help="training stage")
+ parser.add_argument("--stage", type=int, default=0, help="training stage")
parser.add_argument("--no-amp", action="store_false", dest="amp", help="deactivate amp")
parser.add_argument("--classes", type=int, default=5, dest="out_classes", help="output classes")
parser.add_argument("--mode", type=str, default="original", help="choose either `original` or `fast`")
@@ -350,10 +358,12 @@ def main():
parser.add_argument("--save_interval", type=int, default=10)
parser.add_argument("--cpu", type=int, default=8, dest="num_workers", help="number of workers")
parser.add_argument("--no-gpu", action="store_false", dest="use_gpu", help="deactivate use of gpu")
- parser.add_argument("--ckpt", type=str, dest="ckpt_path", help="checkpoint path")
+ parser.add_argument("--ckpt", type=str, dest="ckpt", help="model checkpoint path")
args = parser.parse_args()
cfg = vars(args)
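+    # for stage 1, fall back to the final stage-0 checkpoint when --ckpt is not given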
+ if cfg["stage"] == 1 and not cfg["ckpt"] and cfg["log_dir"]:
+ cfg["ckpt"] = os.path.join(cfg["log_dir"], "stage0", "model.pt")
print(cfg)
logging.basicConfig(level=logging.INFO)