diff --git a/README.md b/README.md
index f28ce6605c..ae19955f0e 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,10 @@ The example is a PyTorch Ignite program and shows several key features of MONAI,
#### [COVID 19-20 challenge baseline](./3d_segmentation/challenge_baseline)
This folder provides a simple baseline method for training, validation, and inference for [COVID-19 LUNG CT LESION SEGMENTATION CHALLENGE - 2020](https://covid-segmentation.grand-challenge.org/COVID-19-20/) (a MICCAI Endorsed Event).
+**deepgrow**
+#### [Deepgrow](./deepgrow)
+This example shows how to train and validate a 2D/3D deepgrow model. It also demonstrates running inference with trained deepgrow models.
+
**federated learning**
#### [Substra](./federated_learning/substra)
The example show how to execute the 3d segmentation torch tutorial on a federated learning platform, Substra.
diff --git a/deepgrow/ignite/README.md b/deepgrow/ignite/README.md
new file mode 100644
index 0000000000..398f44f39d
--- /dev/null
+++ b/deepgrow/ignite/README.md
@@ -0,0 +1,142 @@
+# Deepgrow Examples
+This folder contains examples to train and validate a 2D/3D deepgrow model.
+It also has notebooks to run inference with trained models.
+
+### 1. Data
+
+Training a deepgrow model requires data. The publicly available datasets used in these examples can be downloaded from [Medical Segmentation Decathlon](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2) or [Synapse](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789).
+
+### 2. Questions and bugs
+
+- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
+- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
+- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).
+
+### 3. List of notebooks and examples
+#### [Prepare Your Data](./prepare_dataset.py)
+This example is a standard PyTorch program that prepares the training input for a 2D or 3D model.
+
+```bash
+# List all available options
+python ./prepare_dataset.py -h
+
+# Prepare dataset to train a 2D Deepgrow model
+python ./prepare_dataset.py \
+    --dimensions 2 \
+    --dataset_root MSD_Task09_Spleen \
+    --dataset_json MSD_Task09_Spleen/dataset.json \
+    --output deepgrow/2D/MSD_Task09_Spleen
+
+# Prepare dataset to train a 3D Deepgrow model
+python ./prepare_dataset.py \
+    --dimensions 3 \
+    --dataset_root MSD_Task09_Spleen \
+    --dataset_json MSD_Task09_Spleen/dataset.json \
+    --output deepgrow/3D/MSD_Task09_Spleen
+```
+
+#### [Deepgrow 2D Training](./train.py)
+This example is a standard PyTorch program that trains a 2D deepgrow model on the pre-processed dataset.
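+The `--input` below is the `dataset.json` written by `prepare_dataset.py` above. A quick way to sanity-check that file is shown here (a sketch; the exact record layout comes from `monai.apps.deepgrow.dataset.create_dataset`, and the `image`/`label` paths and `region` id in the comment are illustrative):
+```python
+import json
+
+# Inspect the pre-processed datalist produced by prepare_dataset.py
+with open('deepgrow/2D/MSD_Task09_Spleen/dataset.json') as f:
+    datalist = json.load(f)
+print(len(datalist), datalist[0])  # e.g. {'image': '...npy', 'label': '...npy', 'region': 1}
+```
+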
+```bash
+# List all available options
+python ./train.py -h
+
+# Train a 2D Deepgrow model on a single GPU
+python ./train.py \
+    --input deepgrow/2D/MSD_Task09_Spleen/dataset.json \
+    --output models/2D \
+    --epochs 50
+
+# Train a 2D Deepgrow model on multiple GPUs (NVIDIA)
+python -m torch.distributed.launch \
+    --nproc_per_node=`nvidia-smi -L | wc -l` \
+    --nnodes=1 \
+    --node_rank=0 \
+    --master_addr="localhost" \
+    --master_port=1234 \
+    -m train \
+    --multi_gpu true \
+    --input deepgrow/2D/MSD_Task09_Spleen/dataset.json \
+    --output models/2D \
+    --epochs 50
+
+# After training, export/save the model as TorchScript
+python ./train.py \
+    --input models/2D/model.pt \
+    --output models/2D/model.ts \
+    --export true
+```
+
+#### [Deepgrow 2D Validation](./validate.py)
+This example is a standard PyTorch program that runs evaluation for a trained 2D model.
+```bash
+# List all available options
+python ./validate.py -h
+
+# Evaluate a 2D Deepgrow model
+python ./validate.py \
+    --input deepgrow/2D/MSD_Task09_Spleen/dataset.json \
+    --output eval/2D \
+    --model_path models/2D/model.pt
+```
+
+#### [Deepgrow 2D Inference](./inference.ipynb)
+This notebook shows how to run the pre-transforms before inference with a trained Deepgrow 2D model.
+It also shows the post-transforms that produce the final label mask.
+
+
+#### [Deepgrow 3D Training](./train_3d.py)
+This is an extension of [train.py](./train.py) that overrides the default arguments to run 3D training.
+```bash
+# List all available options
+python ./train_3d.py -h
+
+# Train a 3D Deepgrow model on a single GPU
+python ./train_3d.py \
+    --input deepgrow/3D/MSD_Task09_Spleen/dataset.json \
+    --output models/3D \
+    --epochs 100
+
+# Train a 3D Deepgrow model on multiple GPUs (NVIDIA)
+python -m torch.distributed.launch \
+    --nproc_per_node=`nvidia-smi -L | wc -l` \
+    --nnodes=1 \
+    --node_rank=0 \
+    --master_addr="localhost" \
+    --master_port=1234 \
+    -m train_3d \
+    --multi_gpu true \
+    --input deepgrow/3D/MSD_Task09_Spleen/dataset.json \
+    --output models/3D \
+    --epochs 100
+
+# After training, export/save the model as TorchScript
+python ./train_3d.py \
+    --input models/3D/model.pt \
+    --output models/3D/model.ts \
+    --export true
+```
+
+#### [Deepgrow 3D Validation](./validate_3d.py)
+This is an extension of [validate.py](./validate.py) that overrides the default arguments to run 3D validation.
+```bash
+# List all available options
+python ./validate_3d.py -h
+
+# Evaluate a 3D Deepgrow model
+python ./validate_3d.py \
+    --input deepgrow/3D/MSD_Task09_Spleen/dataset.json \
+    --output eval/3D \
+    --model_path models/3D/model.pt
+```
+
+#### [Deepgrow 3D Inference](./inference_3d.ipynb)
+This notebook shows how to run any pre-transforms before inference with a trained Deepgrow 3D model.
+It also shows the post-transforms that produce the final label mask.
+
+
+#### [Deepgrow Stats](./handler.py)
+This module contains a basic Ignite handler that captures region/organ-wise statistics and saves snapshots and outputs while running training/validation over a dataset with multi-label masks.
+By default, the handler is attached as part of the training/validation steps.
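+The snippet below is a minimal sketch of wiring the handler into an engine yourself (here `evaluator` stands for any `monai.engines.SupervisedEvaluator` you have already built, and the argument values are illustrative):
+```python
+from handler import DeepgrowStatsHandler
+
+# Log a per-region dice scalar and an image snapshot to TensorBoard every epoch
+stats = DeepgrowStatsHandler(log_dir='output', tag_name='val_dice', image_interval=1)
+stats.attach(evaluator)  # registers iteration- and epoch-completed callbacks
+```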
+ +![snashot](./stats.png) diff --git a/deepgrow/ignite/__init__.py b/deepgrow/ignite/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deepgrow/ignite/handler.py b/deepgrow/ignite/handler.py new file mode 100644 index 0000000000..60c6bc5966 --- /dev/null +++ b/deepgrow/ignite/handler.py @@ -0,0 +1,284 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import statistics + +import numpy as np +import torch +import torch.distributed + +from monai.engines.workflow import Engine, Events +from monai.handlers.tensorboard_handlers import SummaryWriter +from monai.metrics import compute_meandice +from monai.transforms import rescale_array +from monai.utils import optional_import +from monai.visualize import plot_2d_or_3d_image + +nib, _ = optional_import("nibabel") +torchvision, _ = optional_import("torchvision") +make_grid, _ = optional_import("torchvision.utils", name="make_grid") +Image, _ = optional_import("PIL.Image") +ImageDraw, _ = optional_import("PIL.ImageDraw") + + +class RegionDice: + def __init__(self): + self.data = [] + + def reset(self): + self.data = [] + + def update(self, y_pred, y, batched=True): + if not batched: + y_pred = y_pred[None] + y = y[None] + score = compute_meandice(y_pred=y_pred, y=y, include_background=False).mean() + self.data.append(score.item()) + + def mean(self): + return statistics.mean(self.data) + + def stdev(self): + return statistics.stdev(self.data) if len(self.data) > 1 else 0 + + +class DeepgrowStatsHandler: + def __init__( + self, + summary_writer=None, + interval=1, + log_dir="./runs", + tag_name="val_dice", + compute_metric=True, + images=True, + image_interval=1, + max_channels=1, + max_frames=64, + add_scalar=True, + merge_scalar=False, + fold_size=0, + ): + self.writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer + self.interval = interval + self.tag_name = tag_name + self.compute_metric = compute_metric + self.images = images + self.image_interval = image_interval + self.max_channels = max_channels + self.max_frames = max_frames + self.add_scalar = add_scalar + self.merge_scalar = merge_scalar + self.fold_size = fold_size + self.logger = logging.getLogger(__name__) + + if torch.distributed.is_initialized(): + self.tag_name = "{}-r{}".format(self.tag_name, torch.distributed.get_rank()) + + self.plot_data = {} + self.metric_data = {} + + def attach(self, engine: Engine) -> None: + engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration") + engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self, "epoch") + + def write_images(self, epoch): + if not self.plot_data or not len(self.plot_data): + return + + all_imgs = [] + for region in sorted(self.plot_data.keys()): + metric = self.metric_data.get(region) + region_data = self.plot_data[region] + if len(region_data[0].shape) == 3: + ti = Image.new("RGB", region_data[0].shape[1:]) + d = ImageDraw.Draw(ti) + t = "region: {}".format(region) + if 
self.compute_metric: + t = t + "\ndice: {:.4f}".format(metric.mean()) + t = t + "\nstdev: {:.4f}".format(metric.stdev()) + d.multiline_text((10, 10), t, fill=(255, 255, 0)) + ti = rescale_array(np.rollaxis(np.array(ti), 2, 0)[0][np.newaxis]) + all_imgs.append(ti) + all_imgs.extend(region_data) + + if len(all_imgs[0].shape) == 3: + img_tensor = make_grid(tensor=torch.from_numpy(np.array(all_imgs)), nrow=4, normalize=True, pad_value=2) + self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=img_tensor, global_step=epoch) + + if len(all_imgs[0].shape) == 4: + for region in sorted(self.plot_data.keys()): + tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] + if torch.distributed.is_initialized(): + rank = "r{}-".format(torch.distributed.get_rank()) + tags = [rank + tags[0], rank + tags[1], rank + tags[2]] + for i in range(3): + img = self.plot_data[region][i] + img = np.moveaxis(img, -3, -1) + plot_2d_or_3d_image( + img[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tags[i] + ) + + self.logger.info( + "Saved {} Regions {} into Tensorboard at epoch: {}".format( + len(self.plot_data), sorted([*self.plot_data]), epoch + ) + ) + self.writer.flush() + + def write_region_metrics(self, epoch): + metric_sum = 0 + means = {} + for region in self.metric_data: + metric = self.metric_data[region].mean() + self.logger.info( + "Epoch[{}] Metrics -- Region: {:0>2d}, {}: {:.4f}".format(epoch, region, self.tag_name, metric) + ) + + if self.merge_scalar: + means["{:0>2d}".format(region)] = metric + else: + self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) + metric_sum += metric + + if self.merge_scalar: + means["avg"] = metric_sum / len(self.metric_data) + self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) + elif len(self.metric_data) > 1: + metric_avg = metric_sum / len(self.metric_data) + self.writer.add_scalar("{}_regions_avg".format(self.tag_name), metric_avg, epoch) + self.writer.flush() + + def __call__(self, engine: Engine, action) -> None: + total_steps = engine.state.iteration + if total_steps < engine.state.epoch_length: + total_steps = engine.state.epoch_length * (engine.state.epoch - 1) + total_steps + + if action == "epoch" and not self.fold_size: + epoch = engine.state.epoch + elif self.fold_size and total_steps % self.fold_size == 0: + epoch = int(total_steps / self.fold_size) + else: + epoch = None + + if epoch: + if self.images and epoch % self.image_interval == 0: + self.write_images(epoch) + if self.add_scalar: + self.write_region_metrics(epoch) + + if action == "epoch" or epoch: + self.plot_data = {} + self.metric_data = {} + return + + device = engine.state.device + batch_data = engine.state.batch + output_data = engine.state.output + + for bidx in range(len(batch_data.get("region", []))): + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + if self.images and self.plot_data.get(region) is None: + self.plot_data[region] = [ + rescale_array(batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis], 0, 1), + rescale_array(batch_data["label"][bidx].detach().cpu().numpy(), 0, 1), + rescale_array(output_data["pred"][bidx].detach().cpu().numpy(), 0, 1), + ] + + if self.compute_metric: + if self.metric_data.get(region) is None: + self.metric_data[region] = RegionDice() + self.metric_data[region].update( + y_pred=output_data["pred"][bidx].to(device), y=batch_data["label"][bidx].to(device), 
batched=False + ) + + +class SegmentationSaver: + def __init__( + self, + output_dir: str = "./runs", + save_np=False, + images=True, + ): + self.output_dir = output_dir + self.save_np = save_np + self.images = images + os.makedirs(self.output_dir, exist_ok=True) + + def attach(self, engine: Engine) -> None: + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def __call__(self, engine: Engine): + batch_data = engine.state.batch + output_data = engine.state.output + device = engine.state.device + tag = "" + if torch.distributed.is_initialized(): + tag = "r{}-".format(torch.distributed.get_rank()) + + for bidx in range(len(batch_data.get("image"))): + step = engine.state.iteration + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + image = batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis] + label = batch_data["label"][bidx].detach().cpu().numpy() + pred = output_data["pred"][bidx].detach().cpu().numpy() + dice = compute_meandice( + y_pred=output_data["pred"][bidx][None].to(device), + y=batch_data["label"][bidx][None].to(device), + include_background=False, + ).mean() + + if self.save_np: + np.savez( + os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(tag, region, step, bidx, dice), + ), + image, + label, + pred, + ) + + if self.images and len(image.shape) == 3: + img = make_grid(torch.from_numpy(rescale_array(image, 0, 1)[0])) + lab = make_grid(torch.from_numpy(rescale_array(label, 0, 1)[0])) + + pos = rescale_array(output_data["image"][bidx][1].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + neg = rescale_array(output_data["image"][bidx][2].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + pre = make_grid(torch.from_numpy(np.array([rescale_array(pred, 0, 1)[0], pos, neg]))) + + torchvision.utils.save_image( + tensor=[img, lab, pre], + nrow=3, + pad_value=2, + fp=os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".format(tag, region, step, bidx, dice), + ), + ) + + if self.images and len(image.shape) == 4: + samples = {"image": image[0], "label": label[0], "pred": pred[0]} + for sample in samples: + img = np.moveaxis(samples[sample], -3, -1) + img = nib.Nifti1Image(img, np.eye(4)) + nib.save( + img, + os.path.join( + self.output_dir, "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(tag, sample, step, bidx, dice) + ), + ) diff --git a/deepgrow/ignite/inference.ipynb b/deepgrow/ignite/inference.ipynb new file mode 100644 index 0000000000..385c750556 --- /dev/null +++ b/deepgrow/ignite/inference.ipynb @@ -0,0 +1,227 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from torch import jit\n", + "\n", + "from monai.apps.deepgrow.transforms import (\n", + " AddGuidanceFromPointsd,\n", + " AddGuidanceSignald,\n", + " Fetch2DSliced,\n", + " ResizeGuidanced,\n", + " RestoreLabeld,\n", + " SpatialCropGuidanced,\n", + ")\n", + "from monai.transforms import (\n", + " AsChannelFirstd,\n", + " Spacingd,\n", + " LoadImaged,\n", + " AddChanneld,\n", + " NormalizeIntensityd,\n", + " ToTensord,\n", + " ToNumpyd,\n", + " Activationsd,\n", + " AsDiscreted,\n", + " Resized\n", + ")\n", + "\n", + "max_epochs = 1\n", + "\n", + "\n", + "def draw_points(guidance):\n", + " if guidance is None:\n", + " return\n", + " colors = ['r+', 'b+']\n", + " for 
color, points in zip(colors, guidance):\n", + " for p in points:\n", + " p1 = p[-1]\n", + " p2 = p[-2]\n", + " plt.plot(p1, p2, color, 'MarkerSize', 30)\n", + "\n", + "\n", + "def show_image(image, label, guidance=None):\n", + " plt.figure(\"check\", (12, 6))\n", + " plt.subplot(1, 2, 1)\n", + " plt.title(\"image\")\n", + " plt.imshow(image, cmap=\"gray\")\n", + "\n", + " if label is not None:\n", + " masked = np.ma.masked_where(label == 0, label)\n", + " plt.imshow(masked, 'jet', interpolation='none', alpha=0.7)\n", + "\n", + " draw_points(guidance)\n", + " plt.colorbar()\n", + "\n", + " if label is not None:\n", + " plt.subplot(1, 2, 2)\n", + " plt.title(\"label\")\n", + " plt.imshow(label)\n", + " plt.colorbar()\n", + " # draw_points(guidance)\n", + " plt.show()\n", + "\n", + "\n", + "def print_data(data):\n", + " for k in data:\n", + " v = data[k]\n", + "\n", + " d = type(v)\n", + " if type(v) in (int, float, bool, str, dict, tuple):\n", + " d = v\n", + " elif hasattr(v, 'shape'):\n", + " d = v.shape\n", + "\n", + " if k in ('image_meta_dict', 'label_meta_dict'):\n", + " for m in data[k]:\n", + " print('{} Meta:: {} => {}'.format(k, m, data[k][m]))\n", + " else:\n", + " print('Data key: {} = {}'.format(k, d))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pre Processing\n", + "roi_size = [256, 256]\n", + "model_size = [128, 192, 192]\n", + "pixdim = (1.0, 1.0)\n", + "dimensions = 2\n", + "\n", + "data = {\n", + " 'image': '_image.nii.gz',\n", + " 'foreground': [[66, 180, 105]],\n", + " 'background': []\n", + "}\n", + "slice_idx = original_slice_idx = data['foreground'][0][2]\n", + "\n", + "pre_transforms = [\n", + " LoadImaged(keys='image'),\n", + " AsChannelFirstd(keys='image'),\n", + " Spacingd(keys='image', pixdim=pixdim, mode='bilinear'),\n", + "\n", + " AddGuidanceFromPointsd(ref_image='image', guidance='guidance', foreground='foreground', background='background',\n", + " dimensions=dimensions),\n", + " Fetch2DSliced(keys='image', guidance='guidance'),\n", + " AddChanneld(keys='image'),\n", + "\n", + " SpatialCropGuidanced(keys='image', guidance='guidance', spatial_size=roi_size),\n", + " Resized(keys='image', spatial_size=roi_size, mode='area'),\n", + " ResizeGuidanced(guidance='guidance', ref_image='image'),\n", + " NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),\n", + " AddGuidanceSignald(image='image', guidance='guidance'),\n", + " ToTensord(keys='image')\n", + "]\n", + "\n", + "original_image = None\n", + "original_image_slice = None\n", + "for t in pre_transforms:\n", + " tname = type(t).__name__\n", + "\n", + " data = t(data)\n", + " image = data['image']\n", + " label = data.get('label')\n", + " guidance = data.get('guidance')\n", + "\n", + " print(\"{} => image shape: {}, label shape: {}\".format(\n", + " tname, image.shape, label.shape if label is not None else None))\n", + "\n", + " image = image if tname == 'Fetch2DSliced' else image[:, :, slice_idx] if tname in (\n", + " 'LoadImaged') else image[slice_idx, :, :]\n", + " label = None\n", + "\n", + " guidance = guidance if guidance else [np.roll(data['foreground'], 1).tolist(), []]\n", + " print('Guidance: {}'.format(guidance))\n", + "\n", + " show_image(image, label, guidance)\n", + " if tname == 'Fetch2DSliced':\n", + " slice_idx = 0\n", + " if tname == 'LoadImaged':\n", + " original_image = data['image']\n", + " if tname == 'AddChanneld':\n", + " original_image_slice = data['image']" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Evaluation\n", + "model_path = '/workspace/Data/models/deepgrow_2d.ts'\n", + "model = jit.load(model_path)\n", + "model.cuda()\n", + "model.eval()\n", + "\n", + "inputs = data['image'][None].cuda()\n", + "with torch.no_grad():\n", + " outputs = model(inputs)\n", + "outputs = outputs[0]\n", + "data['pred'] = outputs\n", + "\n", + "post_transforms = [\n", + " Activationsd(keys='pred', sigmoid=True),\n", + " AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5),\n", + " ToNumpyd(keys='pred'),\n", + " RestoreLabeld(keys='pred', ref_image='image', mode='nearest'),\n", + "]\n", + "\n", + "for t in post_transforms:\n", + " tname = type(t).__name__\n", + "\n", + " data = t(data)\n", + " image = original_image if tname == 'RestoreLabeld' else data['image']\n", + " label = data['pred']\n", + " print(\"{} => image shape: {}, pred shape: {}\".format(tname, image.shape, label.shape))\n", + "\n", + " if tname in 'RestoreLabeld':\n", + " image = image[:, :, original_slice_idx]\n", + " label = label[0, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[original_slice_idx]\n", + " print(\"PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))\n", + " show_image(image, label)\n", + " else:\n", + " image = image[0, :, :].detach().cpu().numpy() if torch.is_tensor(image) else image[0]\n", + " label = label[0, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[0]\n", + " print(\"PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))\n", + " show_image(image, label)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/deepgrow/ignite/inference_3d.ipynb b/deepgrow/ignite/inference_3d.ipynb new file mode 100644 index 0000000000..1438694e15 --- /dev/null +++ b/deepgrow/ignite/inference_3d.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from torch import jit\n", + "\n", + "from monai.apps.deepgrow.transforms import (\n", + " AddGuidanceFromPointsd,\n", + " AddGuidanceSignald,\n", + " ResizeGuidanced,\n", + " RestoreLabeld,\n", + " SpatialCropGuidanced,\n", + ")\n", + "from monai.data import write_nifti\n", + "from monai.transforms import (\n", + " AsChannelFirstd,\n", + " Spacingd,\n", + " LoadImaged,\n", + " AddChanneld,\n", + " NormalizeIntensityd,\n", + " ToTensord,\n", + " ToNumpyd,\n", + " Activationsd,\n", + " AsDiscreted,\n", + " Resized\n", + ")\n", + "\n", + "max_epochs = 1\n", + "\n", + "\n", + "def draw_points(guidance, slice_idx):\n", + " if guidance is None:\n", + " return\n", + " colors = ['r+', 'b+']\n", + " for color, points in zip(colors, guidance):\n", + " for p in points:\n", + " if p[0] != slice_idx:\n", + " continue\n", + " 
p1 = p[-1]\n", + " p2 = p[-2]\n", + " plt.plot(p1, p2, color, 'MarkerSize', 30)\n", + "\n", + "\n", + "def show_image(image, label, guidance=None, slice_idx=None):\n", + " plt.figure(\"check\", (12, 6))\n", + " plt.subplot(1, 2, 1)\n", + " plt.title(\"image\")\n", + " plt.imshow(image, cmap=\"gray\")\n", + "\n", + " if label is not None:\n", + " masked = np.ma.masked_where(label == 0, label)\n", + " plt.imshow(masked, 'jet', interpolation='none', alpha=0.7)\n", + "\n", + " draw_points(guidance, slice_idx)\n", + " plt.colorbar()\n", + "\n", + " if label is not None:\n", + " plt.subplot(1, 2, 2)\n", + " plt.title(\"label\")\n", + " plt.imshow(label)\n", + " plt.colorbar()\n", + " # draw_points(guidance, slice_idx)\n", + " plt.show()\n", + "\n", + "\n", + "def print_data(data):\n", + " for k in data:\n", + " v = data[k]\n", + "\n", + " d = type(v)\n", + " if type(v) in (int, float, bool, str, dict, tuple):\n", + " d = v\n", + " elif hasattr(v, 'shape'):\n", + " d = v.shape\n", + "\n", + " if k in ('image_meta_dict', 'label_meta_dict'):\n", + " for m in data[k]:\n", + " print('{} Meta:: {} => {}'.format(k, m, data[k][m]))\n", + " else:\n", + " print('Data key: {} = {}'.format(k, d))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pre Processing\n", + "roi_size = [256, 256]\n", + "model_size = [128, 192, 192]\n", + "pixdim = (1.0, 1.0, 1.0)\n", + "dimensions = 3\n", + "\n", + "data = {\n", + " 'image': '_image.nii.gz',\n", + " 'foreground': [[66, 180, 105], [66, 180, 145]],\n", + " 'background': [],\n", + "}\n", + "slice_idx = original_slice_idx = data['foreground'][0][2]\n", + "\n", + "pre_transforms = [\n", + " LoadImaged(keys='image'),\n", + " AsChannelFirstd(keys='image'),\n", + " Spacingd(keys='image', pixdim=pixdim, mode='bilinear'),\n", + " AddGuidanceFromPointsd(ref_image='image', guidance='guidance', foreground='foreground', background='background',\n", + " dimensions=dimensions),\n", + " AddChanneld(keys='image'),\n", + " SpatialCropGuidanced(keys='image', guidance='guidance', spatial_size=roi_size),\n", + " Resized(keys='image', spatial_size=model_size, mode='area'),\n", + " ResizeGuidanced(guidance='guidance', ref_image='image'),\n", + " NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),\n", + " AddGuidanceSignald(image='image', guidance='guidance'),\n", + " ToTensord(keys='image')\n", + "]\n", + "\n", + "original_image = None\n", + "for t in pre_transforms:\n", + " tname = type(t).__name__\n", + " data = t(data)\n", + " image = data['image']\n", + " label = data.get('label')\n", + " guidance = data.get('guidance')\n", + "\n", + " print(\"{} => image shape: {}\".format(tname, image.shape))\n", + "\n", + " guidance = guidance if guidance else [np.roll(data['foreground'], 1).tolist(), []]\n", + " slice_idx = guidance[0][0][0] if guidance else slice_idx\n", + " print('Guidance: {}; Slice Idx: {}'.format(guidance, slice_idx))\n", + " if tname == 'Resized':\n", + " continue\n", + "\n", + " image = image[:, :, slice_idx] if tname in ('LoadImaged') else image[slice_idx] if tname in (\n", + " 'AsChannelFirstd', 'Spacingd', 'AddGuidanceFromPointsd') else image[0][slice_idx]\n", + " label = None\n", + "\n", + " show_image(image, label, guidance, slice_idx)\n", + " if tname == 'LoadImaged':\n", + " original_image = data['image']\n", + " if tname == 'AddChanneld':\n", + " original_image_slice = data['image']\n", + " if tname == 'SpatialCropGuidanced':\n", + " spatial_image = data['image']\n", + "\n", + "image 
= data['image']\n", + "label = data.get('label')\n", + "guidance = data.get('guidance')\n", + "for i in range(image.shape[1]):\n", + " print('Slice Idx: {}'.format(i))\n", + " # show_image(image[0][i], None, guidance, i)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluation\n", + "model_path = '/workspace/Data/models/deepgrow_3d.ts'\n", + "model = jit.load(model_path)\n", + "model.cuda()\n", + "model.eval()\n", + "\n", + "inputs = data['image'][None].cuda()\n", + "with torch.no_grad():\n", + " outputs = model(inputs)\n", + "outputs = outputs[0]\n", + "data['pred'] = outputs\n", + "\n", + "post_transforms = [\n", + " Activationsd(keys='pred', sigmoid=True),\n", + " AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5),\n", + " ToNumpyd(keys='pred'),\n", + " RestoreLabeld(keys='pred', ref_image='image', mode='nearest'),\n", + "]\n", + "\n", + "pred = None\n", + "for t in post_transforms:\n", + " tname = type(t).__name__\n", + "\n", + " data = t(data)\n", + " image = data['image']\n", + " label = data['pred']\n", + " print(\"{} => image shape: {}, pred shape: {}; slice_idx: {}\".format(tname, image.shape, label.shape, slice_idx))\n", + "\n", + " if tname in 'RestoreLabeld':\n", + " pred = label\n", + "\n", + " image = original_image[:, :, original_slice_idx]\n", + " label = label[original_slice_idx]\n", + " print(\"PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))\n", + " show_image(image, label)\n", + " elif tname == 'xToNumpyd':\n", + " for i in range(label.shape[1]):\n", + " img = image[0, i, :, :].detach().cpu().numpy() if torch.is_tensor(image) else image[0][i]\n", + " lab = label[0, i, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[0][i]\n", + " if np.sum(lab) > 0:\n", + " print(\"PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " i, img.shape, lab.shape, np.min(lab), np.max(lab), np.sum(lab)))\n", + " show_image(img, lab)\n", + " else:\n", + " image = image[0, slice_idx, :, :].detach().cpu().numpy() if torch.is_tensor(image) else image[0][slice_idx]\n", + " label = label[0, slice_idx, :, :].detach().cpu().numpy() if torch.is_tensor(label) else label[0][slice_idx]\n", + " print(\"PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " tname, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))\n", + " show_image(image, label)\n", + "\n", + "for i in range(pred.shape[0]):\n", + " image = original_image[:, :, i]\n", + " label = pred[i, :, :]\n", + " if np.sum(label) == 0:\n", + " continue\n", + "\n", + " print(\"Final PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}\".format(\n", + " i, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))\n", + " show_image(image, label)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "pred = data['pred']\n", + "meta_data = data['pred_meta_dict']\n", + "affine = meta_data.get(\"affine\", None)\n", + "\n", + "pred = np.moveaxis(pred, 0, -1)\n", + "print('Prediction NII shape: {}'.format(pred.shape))\n", + "\n", + "file_name = 'result_label.nii.gz'\n", + "write_nifti(pred, file_name=file_name)\n", + "print('Prediction saved at: {}'.format(file_name))" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/deepgrow/ignite/prepare_dataset.py b/deepgrow/ignite/prepare_dataset.py new file mode 100644 index 0000000000..0cdca209f0 --- /dev/null +++ b/deepgrow/ignite/prepare_dataset.py @@ -0,0 +1,81 @@ +import argparse +import distutils.util +import json +import logging +import os +import sys + +from monai.apps.deepgrow.dataset import create_dataset + + +def prepare_datalist(args): + dimensions = args.dimensions + dataset_json = os.path.join(args.output, 'dataset.json') + if not os.path.exists(dataset_json): + logging.info('Processing dataset...') + with open(os.path.join(args.dataset_json)) as f: + datalist = json.load(f) + + datalist = create_dataset( + datalist=datalist[args.datalist_key], + base_dir=args.dataset_root, + output_dir=args.output, + dimension=dimensions, + pixdim=[1.0] * dimensions, + limit=args.limit, + relative_path=args.relative_path + ) + + with open(dataset_json, 'w') as fp: + json.dump(datalist, fp, indent=2) + else: + logging.info('Pre-load existing dataset.json') + + dataset_json = os.path.join(args.output, 'dataset.json') + with open(dataset_json) as f: + datalist = json.load(f) + logging.info('+++ Dataset File: {}'.format(dataset_json)) + logging.info('+++ Total Records: {}'.format(len(datalist))) + logging.info('') + + +def run(args): + for arg in vars(args): + logging.info('USING:: {} = {}'.format(arg, getattr(args, arg))) + logging.info("") + + if not os.path.exists(args.output): + logging.info('output path [{}] does not exist. 
creating it now.'.format(args.output)) + os.makedirs(args.output, exist_ok=True) + prepare_datalist(args) + + +def strtobool(val): + return bool(distutils.util.strtobool(val)) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--seed', type=int, default=42) + parser.add_argument('-dims', '--dimensions', type=int, default=2) + + parser.add_argument('-d', '--dataset_root', default='/workspace/data/MSD_Task09_Spleen') + parser.add_argument('-j', '--dataset_json', default='/workspace/data/MSD_Task09_Spleen/dataset.json') + parser.add_argument('-k', '--datalist_key', default='training') + + parser.add_argument('-o', '--output', default='/workspace/data/deepgrow/2D/MSD_Task09_Spleen') + parser.add_argument('-t', '--limit', type=int, default=0) + parser.add_argument('-r', '--relative_path', type=strtobool, default='false') + + args = parser.parse_args() + run(args) + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format='[%(asctime)s.%(msecs)03d][%(levelname)5s] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + main() diff --git a/deepgrow/ignite/stats.png b/deepgrow/ignite/stats.png new file mode 100644 index 0000000000..abae792137 Binary files /dev/null and b/deepgrow/ignite/stats.png differ diff --git a/deepgrow/ignite/train.py b/deepgrow/ignite/train.py new file mode 100644 index 0000000000..2483710240 --- /dev/null +++ b/deepgrow/ignite/train.py @@ -0,0 +1,366 @@ +import argparse +import distutils.util +import json +import logging +import os +import sys +import time + +import torch +import torch.distributed as dist + +from monai.apps.deepgrow.interaction import Interaction +from monai.apps.deepgrow.transforms import ( + SpatialCropForegroundd, + AddInitialSeedPointd, + FindDiscrepancyRegionsd, + AddRandomGuidanced, + AddGuidanceSignald, + FindAllValidSlicesd, +) +from monai.data import partition_dataset +from monai.data.dataloader import DataLoader +from monai.data.dataset import PersistentDataset +from monai.engines import SupervisedEvaluator +from monai.engines import SupervisedTrainer +from monai.handlers import ( + StatsHandler, + TensorBoardStatsHandler, + ValidationHandler, + LrScheduleHandler, + CheckpointSaver, + MeanDice) +from monai.inferers import SimpleInferer +from monai.losses import DiceLoss +from monai.networks.layers import Norm +from monai.networks.nets import BasicUNet, UNet +from monai.transforms import ( + Compose, + LoadImaged, + AddChanneld, + NormalizeIntensityd, + ToTensord, + ToNumpyd, + Activationsd, + AsDiscreted, + Resized, +) +from monai.utils import set_determinism +from handler import DeepgrowStatsHandler + + +def get_network(network, channels, dimensions): + if network == 'unet': + if channels == 16: + features = (16, 32, 64, 128, 256) + elif channels == 32: + features = (32, 64, 128, 256, 512) + else: + features = (64, 128, 256, 512, 1024) + logging.info('Using Unet with features: {}'.format(features)) + network = UNet(dimensions=dimensions, in_channels=3, out_channels=1, channels=features, strides=[2, 2, 2, 2], + norm=Norm.BATCH) + else: + if channels == 16: + features = (16, 32, 64, 128, 256, 16) + elif channels == 32: + features = (32, 64, 128, 256, 512, 32) + else: + features = (64, 128, 256, 512, 1024, 64) + logging.info('Using BasicUnet with features: {}'.format(features)) + network = BasicUNet(dimensions=dimensions, in_channels=3, out_channels=1, features=features) + return network + + +def get_pre_transforms(roi_size, model_size, dimensions): + t = [ + 
LoadImaged(keys=('image', 'label')), + AddChanneld(keys=('image', 'label')), + SpatialCropForegroundd(keys=('image', 'label'), source_key='label', spatial_size=roi_size), + Resized(keys=('image', 'label'), spatial_size=model_size, mode=('area', 'nearest')), + NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0) + ] + if dimensions == 3: + t.append(FindAllValidSlicesd(label='label', sids='sids')) + t.extend([ + AddInitialSeedPointd(label='label', guidance='guidance', sids='sids'), + AddGuidanceSignald(image='image', guidance='guidance'), + ToTensord(keys=('image', 'label')) + ]) + return Compose(t) + + +def get_click_transforms(): + return Compose([ + Activationsd(keys='pred', sigmoid=True), + ToNumpyd(keys=('image', 'label', 'pred', 'probability', 'guidance')), + FindDiscrepancyRegionsd(label='label', pred='pred', discrepancy='discrepancy', batched=True), + AddRandomGuidanced(guidance='guidance', discrepancy='discrepancy', probability='probability', batched=True), + AddGuidanceSignald(image='image', guidance='guidance', batched=True), + ToTensord(keys=('image', 'label')) + ]) + + +def get_post_transforms(): + return Compose([ + Activationsd(keys='pred', sigmoid=True), + AsDiscreted(keys='pred', threshold_values=True, logit_thresh=0.5) + ]) + + +def get_loaders(args, pre_transforms, train=True): + multi_gpu = args.multi_gpu + local_rank = args.local_rank + + dataset_json = os.path.join(args.input) + with open(dataset_json) as f: + datalist = json.load(f) + + total_d = len(datalist) + datalist = datalist[0:args.limit] if args.limit else datalist + total_l = len(datalist) + + if multi_gpu: + datalist = partition_dataset( + data=datalist, + num_partitions=dist.get_world_size(), + even_divisible=True, + shuffle=True, + seed=args.seed + )[local_rank] + + if train: + train_datalist, val_datalist = partition_dataset( + datalist, + ratios=[args.split, (1 - args.split)], + shuffle=True, + seed=args.seed) + + train_ds = PersistentDataset(train_datalist, pre_transforms, cache_dir=args.cache_dir) + train_loader = DataLoader( + train_ds, + batch_size=args.batch, + shuffle=True, + num_workers=16) + logging.info('{}:: Total Records used for Training is: {}/{}/{}'.format( + local_rank, len(train_ds), total_l, total_d)) + else: + train_loader = None + val_datalist = datalist + + val_ds = PersistentDataset(val_datalist, pre_transforms, cache_dir=args.cache_dir) + val_loader = DataLoader(val_ds, batch_size=args.batch, num_workers=8) + logging.info('{}:: Total Records used for Validation is: {}/{}/{}'.format( + local_rank, len(val_ds), total_l, total_d)) + + return train_loader, val_loader + + +def create_trainer(args): + set_determinism(seed=args.seed) + + multi_gpu = args.multi_gpu + local_rank = args.local_rank + if multi_gpu: + dist.init_process_group(backend="nccl", init_method="env://") + device = torch.device("cuda:{}".format(local_rank)) + torch.cuda.set_device(device) + else: + device = torch.device("cuda" if args.use_gpu else "cpu") + + pre_transforms = get_pre_transforms(args.roi_size, args.model_size, args.dimensions) + click_transforms = get_click_transforms() + post_transform = get_post_transforms() + + train_loader, val_loader = get_loaders(args, pre_transforms) + + # define training components + network = get_network(args.network, args.channels, args.dimensions).to(device) + if multi_gpu: + network = torch.nn.parallel.DistributedDataParallel(network, device_ids=[local_rank], output_device=local_rank) + + if args.resume: + logging.info('{}:: Loading Network...'.format(local_rank)) 
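+        # Remap tensors saved on the rank-0 GPU ("cuda:0") onto this process's device,
+        # so resuming works for both single-GPU and multi-GPU (DDP) runs.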
+ map_location = {"cuda:0": "cuda:{}".format(local_rank)} + network.load_state_dict(torch.load(args.model_filepath, map_location=map_location)) + + # define event-handlers for engine + val_handlers = [ + StatsHandler(output_transform=lambda x: None), + TensorBoardStatsHandler(log_dir=args.output, output_transform=lambda x: None), + DeepgrowStatsHandler(log_dir=args.output, tag_name='val_dice', image_interval=args.image_interval), + CheckpointSaver(save_dir=args.output, save_dict={"net": network}, save_key_metric=True, save_final=True, + save_interval=args.save_interval, final_filename='model.pt') + ] + val_handlers = val_handlers if local_rank == 0 else None + + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=val_loader, + network=network, + iteration_update=Interaction( + transforms=click_transforms, + max_interactions=args.max_val_interactions, + key_probability='probability', + train=False), + inferer=SimpleInferer(), + post_transform=post_transform, + key_val_metric={ + "val_dice": MeanDice( + include_background=False, + output_transform=lambda x: (x["pred"], x["label"]) + ) + }, + val_handlers=val_handlers + ) + + loss_function = DiceLoss(sigmoid=True, squared_pred=True) + optimizer = torch.optim.Adam(network.parameters(), args.learning_rate) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1) + + train_handlers = [ + LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), + ValidationHandler(validator=evaluator, interval=args.val_freq, epoch_level=True), + StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), + TensorBoardStatsHandler(log_dir=args.output, tag_name="train_loss", output_transform=lambda x: x["loss"]), + CheckpointSaver(save_dir=args.output, save_dict={"net": network, "opt": optimizer, "lr": lr_scheduler}, + save_interval=args.save_interval * 2, save_final=True, final_filename='checkpoint.pt'), + ] + train_handlers = train_handlers if local_rank == 0 else train_handlers[:2] + + trainer = SupervisedTrainer( + device=device, + max_epochs=args.epochs, + train_data_loader=train_loader, + network=network, + iteration_update=Interaction( + transforms=click_transforms, + max_interactions=args.max_train_interactions, + key_probability='probability', + train=True), + optimizer=optimizer, + loss_function=loss_function, + inferer=SimpleInferer(), + post_transform=post_transform, + amp=args.amp, + key_train_metric={ + "train_dice": MeanDice( + include_background=False, + output_transform=lambda x: (x["pred"], x["label"]) + ) + }, + train_handlers=train_handlers, + ) + return trainer + + +def run(args): + args.roi_size = json.loads(args.roi_size) + args.model_size = json.loads(args.model_size) + + if args.local_rank == 0: + for arg in vars(args): + logging.info('USING:: {} = {}'.format(arg, getattr(args, arg))) + print("") + + if args.export: + logging.info('{}:: Loading PT Model from: {}'.format(args.local_rank, args.input)) + device = torch.device("cuda" if args.use_gpu else "cpu") + network = get_network(args.network, args.channels, args.dimensions).to(device) + + map_location = {"cuda:0": "cuda:{}".format(args.local_rank)} + network.load_state_dict(torch.load(args.input, map_location=map_location)) + + logging.info('{}:: Saving TorchScript Model'.format(args.local_rank)) + model_ts = torch.jit.script(network) + torch.jit.save(model_ts, os.path.join(args.output)) + return + + if not os.path.exists(args.output): + logging.info('output path [{}] does not exist. 
creating it now.'.format(args.output)) + os.makedirs(args.output, exist_ok=True) + + trainer = create_trainer(args) + + start_time = time.time() + trainer.run() + end_time = time.time() + + logging.info('Total Training Time {}'.format(end_time - start_time)) + if args.local_rank == 0: + logging.info('{}:: Saving Final PT Model'.format(args.local_rank)) + torch.save(trainer.network.state_dict(), os.path.join(args.output, 'model-final.pt')) + + if not args.multi_gpu: + logging.info('{}:: Saving TorchScript Model'.format(args.local_rank)) + model_ts = torch.jit.script(trainer.network) + torch.jit.save(model_ts, os.path.join(args.output, 'model-final.ts')) + + if args.multi_gpu: + dist.destroy_process_group() + + +def strtobool(val): + return bool(distutils.util.strtobool(val)) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--seed', type=int, default=23) + parser.add_argument('--dimensions', type=int, default=2) + + parser.add_argument('-n', '--network', default='bunet', choices=['unet', 'bunet']) + parser.add_argument('-c', '--channels', type=int, default=32) + parser.add_argument('-i', '--input', default='/workspace/data/deepgrow/2D/MSD_Task09_Spleen/dataset.json') + parser.add_argument('-o', '--output', default='output') + + parser.add_argument('-g', '--use_gpu', type=strtobool, default='true') + parser.add_argument('-a', '--amp', type=strtobool, default='false') + + parser.add_argument('-e', '--epochs', type=int, default=100) + parser.add_argument('-b', '--batch', type=int, default=8) + parser.add_argument('-x', '--split', type=float, default=0.9) + parser.add_argument('-t', '--limit', type=int, default=0) + parser.add_argument('--cache_dir', type=str, default=None) + + parser.add_argument('-r', '--resume', type=strtobool, default='false') + parser.add_argument('-m', '--model_path', default="output/model.pt") + parser.add_argument('--roi_size', default="[256, 256]") + parser.add_argument('--model_size', default="[256, 256]") + + parser.add_argument('-f', '--val_freq', type=int, default=1) + parser.add_argument('-lr', '--learning_rate', type=float, default=0.0001) + parser.add_argument('-it', '--max_train_interactions', type=int, default=15) + parser.add_argument('-iv', '--max_val_interactions', type=int, default=5) + + parser.add_argument('--save_interval', type=int, default=3) + parser.add_argument('--image_interval', type=int, default=1) + parser.add_argument('--multi_gpu', type=strtobool, default='false') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--export', type=strtobool, default='false') + + args = parser.parse_args() + run(args) + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format='[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + main() + +''' +# Single GPU (it will also export) +python train.py + +# Multi GPU (run export separate) +python -m torch.distributed.launch \ + --nproc_per_node=`nvidia-smi -L | wc -l` \ + --nnodes=1 --node_rank=0 --master_addr="localhost" --master_port=1234 \ + -m train --multi_gpu true -e 100 + +python train.py --export +''' diff --git a/deepgrow/ignite/train_3d.py b/deepgrow/ignite/train_3d.py new file mode 100644 index 0000000000..c1fbf89fc8 --- /dev/null +++ b/deepgrow/ignite/train_3d.py @@ -0,0 +1,69 @@ +import argparse +import distutils.util +import logging +import sys + +import train + + +def strtobool(val): + return bool(distutils.util.strtobool(val)) + + +if __name__ 
== "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format='[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--seed', type=int, default=23) + parser.add_argument('--dimensions', type=int, default=3) + + parser.add_argument('-n', '--network', default='bunet', choices=['unet', 'bunet']) + parser.add_argument('-c', '--channels', type=int, default=32) + parser.add_argument('-i', '--input', default='/workspace/data/deepgrow/3D/MSD_Task09_Spleen/dataset.json') + parser.add_argument('-o', '--output', default='output3D') + + parser.add_argument('-g', '--use_gpu', type=strtobool, default='true') + parser.add_argument('-a', '--amp', type=strtobool, default='false') + + parser.add_argument('-e', '--epochs', type=int, default=200) + parser.add_argument('-b', '--batch', type=int, default=1) + parser.add_argument('-x', '--split', type=float, default=0.9) + parser.add_argument('-t', '--limit', type=int, default=0) + parser.add_argument('--cache_dir', type=str, default=None) + + parser.add_argument('-r', '--resume', type=strtobool, default='false') + parser.add_argument('-m', '--model_path', default="output3D/model.pt") + parser.add_argument('--roi_size', default="[128, 192, 192]") + parser.add_argument('--model_size', default="[128, 192, 192]") + + parser.add_argument('-f', '--val_freq', type=int, default=1) + parser.add_argument('-lr', '--learning_rate', type=float, default=0.0001) + parser.add_argument('-it', '--max_train_interactions', type=int, default=15) + parser.add_argument('-iv', '--max_val_interactions', type=int, default=20) + + parser.add_argument('--save_interval', type=int, default=20) + parser.add_argument('--image_interval', type=int, default=5) + parser.add_argument('--multi_gpu', type=strtobool, default='false') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--export', type=strtobool, default='false') + + args = parser.parse_args() + train.run(args) + +''' +# Single GPU (it will also export) +python train_3d.py + +# Multi GPU (run export separate) +python -m torch.distributed.launch \ + --nproc_per_node=`nvidia-smi -L | wc -l` \ + --nnodes=1 --node_rank=0 --master_addr="localhost" --master_port=1234 \ + -m train_3d --multi_gpu true -e 100 + +python train.py --export +''' diff --git a/deepgrow/ignite/validate.py b/deepgrow/ignite/validate.py new file mode 100644 index 0000000000..127a4c5041 --- /dev/null +++ b/deepgrow/ignite/validate.py @@ -0,0 +1,147 @@ +import argparse +import distutils.util +import json +import logging +import os +import sys +import time + +import torch + +import train +from handler import DeepgrowStatsHandler, SegmentationSaver +from monai.apps.deepgrow.interaction import Interaction +from monai.engines import SupervisedEvaluator +from monai.handlers import ( + StatsHandler, + TensorBoardStatsHandler, + MeanDice) +from monai.inferers import SimpleInferer +from monai.utils import set_determinism + + +def create_validator(args, click): + set_determinism(seed=args.seed) + + device = torch.device("cuda" if args.use_gpu else "cpu") + + pre_transforms = train.get_pre_transforms(args.roi_size, args.model_size, args.dimensions) + click_transforms = train.get_click_transforms() + post_transform = train.get_post_transforms() + + # define training components + network = train.get_network(args.network, args.channels, args.dimensions).to(device) + + logging.info('Loading Network...') + 
map_location = {"cuda:0": "cuda:{}".format(args.local_rank)} + + checkpoint = torch.load(args.model_path, map_location=map_location) + network.load_state_dict(checkpoint) + network.eval() + + # define event-handlers for engine + _, val_loader = train.get_loaders(args, pre_transforms, train=False) + fold_size = int(len(val_loader.dataset) / args.batch / args.folds) if args.folds else 0 + logging.info('Using Fold-Size: {}'.format(fold_size)) + + val_handlers = [ + StatsHandler(output_transform=lambda x: None), + TensorBoardStatsHandler(log_dir=args.output, output_transform=lambda x: None), + DeepgrowStatsHandler( + log_dir=args.output, + tag_name=f'clicks_{click}_val_dice', + fold_size=int(len(val_loader.dataset) / args.batch / args.folds) if args.folds else 0 + ), + ] + if args.save_seg: + val_handlers.append(SegmentationSaver(output_dir=os.path.join(args.output, f'clicks_{click}_images'))) + + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=val_loader, + network=network, + iteration_update=Interaction( + transforms=click_transforms, + max_interactions=click, + train=False), + inferer=SimpleInferer(), + post_transform=post_transform, + val_handlers=val_handlers, + key_val_metric={ + f'clicks_{click}_val_dice': MeanDice( + include_background=False, + output_transform=lambda x: (x["pred"], x["label"]) + ) + } + ) + return evaluator + + +def run(args): + args.roi_size = json.loads(args.roi_size) + args.model_size = json.loads(args.model_size) + + if args.local_rank == 0: + for arg in vars(args): + logging.info('USING:: {} = {}'.format(arg, getattr(args, arg))) + print("") + + if not os.path.exists(args.output): + logging.info('output path [{}] does not exist. creating it now.'.format(args.output)) + os.makedirs(args.output, exist_ok=True) + + clicks = json.loads(args.max_val_interactions) + for click in clicks: + logging.info('+++++++++++++++++++++++++++++++++++++++++++++++++++++') + logging.info(' CLICKS = {}'.format(click)) + logging.info('+++++++++++++++++++++++++++++++++++++++++++++++++++++') + evaluator = create_validator(args, click) + + start_time = time.time() + evaluator.run() + end_time = time.time() + + logging.info('Total Run Time {}'.format(end_time - start_time)) + + +def strtobool(val): + return bool(distutils.util.strtobool(val)) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--seed', type=int, default=23) + parser.add_argument('--dimensions', type=int, default=2) + + parser.add_argument('-n', '--network', default='bunet', choices=['unet', 'bunet']) + parser.add_argument('-c', '--channels', type=int, default=32) + parser.add_argument('-f', '--folds', type=int, default=10) + + parser.add_argument('-i', '--input', default='/workspace/data/deepgrow/2D/MSD_Task09_Spleen/dataset.json') + parser.add_argument('-o', '--output', default='eval') + parser.add_argument('--save_seg', type=strtobool, default='false') + parser.add_argument('--cache_dir', type=str, default=None) + + parser.add_argument('-g', '--use_gpu', type=strtobool, default='true') + parser.add_argument('-b', '--batch', type=int, default=1) + parser.add_argument('-t', '--limit', type=int, default=20) + parser.add_argument('-m', '--model_path', default="output/model.pt") + parser.add_argument('--roi_size', default="[256, 256]") + parser.add_argument('--model_size', default="[256, 256]") + + parser.add_argument('-iv', '--max_val_interactions', default="[0,1,2,5,10,15]") + parser.add_argument('--multi_gpu', type=strtobool, default='false') + 
parser.add_argument("--local_rank", type=int, default=0) + + args = parser.parse_args() + run(args) + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format='[%(asctime)s.%(msecs)03d][%(levelname)5s] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + main() diff --git a/deepgrow/ignite/validate_3d.py b/deepgrow/ignite/validate_3d.py new file mode 100644 index 0000000000..c4121c8c0c --- /dev/null +++ b/deepgrow/ignite/validate_3d.py @@ -0,0 +1,49 @@ +import argparse +import distutils.util +import logging +import sys + +import validate + + +def strtobool(val): + return bool(distutils.util.strtobool(val)) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('-s', '--seed', type=int, default=23) + parser.add_argument('--dimensions', type=int, default=3) + + parser.add_argument('-n', '--network', default='bunet', choices=['unet', 'bunet']) + parser.add_argument('-c', '--channels', type=int, default=32) + parser.add_argument('-f', '--folds', type=int, default=10) + + parser.add_argument('-i', '--input', default='/workspace/data/deepgrow/3D/MSD_Task09_Spleen/dataset.json') + parser.add_argument('-o', '--output', default='eval3D') + parser.add_argument('--save_seg', type=strtobool, default='false') + parser.add_argument('--cache_dir', type=str, default=None) + + parser.add_argument('-g', '--use_gpu', type=strtobool, default='true') + parser.add_argument('-b', '--batch', type=int, default=1) + parser.add_argument('-t', '--limit', type=int, default=20) + parser.add_argument('-m', '--model_path', default="output3D/model.pt") + parser.add_argument('--roi_size', default="[256, 256, 256]") + parser.add_argument('--model_size', default="[128, 128, 128]") + + parser.add_argument('-iv', '--max_val_interactions', default="[0,1,2,5,10,15]") + parser.add_argument('--multi_gpu', type=strtobool, default='false') + parser.add_argument("--local_rank", type=int, default=0) + + args = parser.parse_args() + validate.run(args) + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format='[%(asctime)s.%(msecs)03d][%(levelname)5s] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + main()