diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000000..9c2dd97bb2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,12 @@ +--- +name: Question +about: Question relating to MONAI +title: '' +labels: '' +assignees: '' +--- + +**Please use MONAI's Discussions tab** +For questions relating to MONAI usage, please do not create an issue. + +Instead, use [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions). This can be found next to Issues and Pull Requests along the top of our repository. \ No newline at end of file diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 1dcc7675f0..003a746de4 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -34,7 +34,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.7.0 torchvision==0.8.1 + python -m pip install torch==1.7.1 torchvision==0.8.2 python -m pip install -r requirements-dev.txt - name: Run integration tests run: | diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 4db50b2723..8e92ea0ed7 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -80,13 +80,13 @@ jobs: - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | - python -m pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html # min. requirements for windows instances python -c "f=open('requirements-dev.txt', 'r'); txt=f.readlines(); f.close(); print(txt); f=open('requirements-dev.txt', 'w'); f.writelines(txt[1:12]); f.close()" - name: Install the dependencies run: | - python -m pip install torch==1.7.0 - python -m pip install torchvision==0.8.1 + python -m pip install torch==1.7.1 + python -m pip install torchvision==0.8.2 cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list @@ -108,7 +108,7 @@ jobs: fail-fast: false matrix: os: [windows-latest, macOS-latest, ubuntu-latest] - timeout-minutes: 20 + timeout-minutes: 40 steps: - uses: actions/checkout@v2 - name: Set up Python 3.8 @@ -134,11 +134,11 @@ jobs: - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | - python -m pip install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install the dependencies run: | # min. requirements - python -m pip install torch==1.7.0 + python -m pip install torch==1.7.1 python -m pip install -r requirements-min.txt python -m pip list BUILD_MONAI=0 python setup.py develop # no compile of extensions @@ -173,7 +173,7 @@ jobs: pytorch: "-h" base: "nvcr.io/nvidia/pytorch:20.07-py3" - environment: PT17+CUDA102 - pytorch: "torch==1.7.0 torchvision==0.8.1" + pytorch: "torch==1.7.1 torchvision==0.8.2" base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04" - environment: PT17+CUDA110 # we explicitly set pytorch to -h to avoid pip install error diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index b8b02e6ace..ed5d560861 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -36,7 +36,7 @@ jobs: which python python -m pip install --upgrade pip wheel python -m pip uninstall -y torch torchvision - python -m pip install torch==1.7.0 torchvision==0.8.1 + python -m pip install torch==1.7.1 torchvision==0.8.2 python -m pip install -r requirements-dev.txt - name: Run unit tests report coverage run: | @@ -82,7 +82,7 @@ jobs: - name: Install the dependencies run: | python -m pip install --upgrade pip wheel - python -m pip install torch==1.7.0 torchvision==0.8.1 + python -m pip install torch==1.7.1 torchvision==0.8.2 python -m pip install -r requirements-dev.txt - name: Run quick tests CPU ubuntu run: | @@ -151,7 +151,8 @@ jobs: run: | docker build -t localhost:5000/local_monai:latest -f Dockerfile . docker push localhost:5000/local_monai:latest - docker tag localhost:5000/local_monai:latest projectmonai/monai:latest + sed -i '/flake/d' requirements-dev.txt + docker build -t projectmonai/monai:latest -f Dockerfile . docker login -u projectmonai -p ${{ secrets.DOCKER_PW }} docker push projectmonai/monai:latest docker logout diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index b6dd43dbe0..54e43d6968 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -29,7 +29,7 @@ jobs: git config user.email "monai.miccai2019@gmail.com" git add setup.cfg monai/__init__.py git commit -m "Weekly build at $HEAD_COMMIT_ID" - git tag 0.4.dev$(date +'%y%U') + git tag 0.5.dev$(date +'%y%U') python setup.py sdist bdist_wheel - name: Publish to PyPI diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fe6627a36..56e65a7d92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,77 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.4.0] - 2020-12-15 +### Added +* Overview document for [feature highlights in v0.4.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md) +* Torchscript support for the net modules +* New networks and layers: + * Discrete Gaussian kernels + * Hilbert transform and envelope detection + * Swish and mish activation + * Acti-norm-dropout block + * Upsampling layer + * Autoencoder, Variational autoencoder + * FCNet +* Support of initialisation from pretrained weights for densenet, senet, multichannel AHNet +* Layer-wise learning rate API +* New model metrics and event handlers based on occlusion sensitivity, confusion matrix, surface distance +* CAM/GradCAM/GradCAM++ +* File format-agnostic image loader APIs with Nibabel, ITK readers +* Enhancements for dataset partition, cross-validation APIs +* New data APIs: + * LMDB-based caching dataset + * Cache-N-transforms dataset + * Iterable dataset + * Patch dataset +* Weekly PyPI release +* Fully compatible with PyTorch 1.7 +* CI/CD enhancements: + * Skipping, speed up, fail fast, timed, quick tests + * Distributed training tests + * Performance profiling utilities +* New tutorials and demos: + * Autoencoder, VAE tutorial + * Cross-validation demo + * Model interpretability tutorial + * COVID-19 Lung CT segmentation challenge open-source baseline + * Threadbuffer demo + * Dataset partitioning tutorial + * Layer-wise learning rate demo + * [MONAI Bootcamp 2020](https://github.com/Project-MONAI/MONAIBootcamp2020) + +### Changed +* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.10-py3` from `nvcr.io/nvidia/pytorch:20.08-py3` + +#### Backwards Incompatible Changes +* `monai.apps.CVDecathlonDataset` is extended to a generic `monai.apps.CrossValidation` with an `dataset_cls` option +* Cache dataset now requires a `monai.transforms.Compose` instance as the transform argument +* Model checkpoint file name extensions changed from `.pth` to `.pt` +* Readers' `get_spatial_shape` returns a numpy array instead of list +* Decoupled postprocessing steps such as `sigmoid`, `to_onehot_y`, `mutually_exclusive`, `logit_thresh` from metrics and event handlers, +the postprocessing steps should be used before calling the metrics methods +* `ConfusionMatrixMetric` and `DiceMetric` computation now returns an additional `not_nans` flag to indicate valid results +* `UpSample` optional `mode` now supports `"deconv"`, `"nontrainable"`, `"pixelshuffle"`; `interp_mode` is only used when `mode` is `"nontrainable"` +* `SegResNet` optional `upsample_mode` now supports `"deconv"`, `"nontrainable"`, `"pixelshuffle"` +* `monai.transforms.Compose` class inherits `monai.transforms.Transform` +* In `Rotate`, `Rotated`, `RandRotate`, `RandRotated` transforms, the `angle` related parameters are interpreted as angles in radians instead of degrees. +* `SplitChannel` and `SplitChanneld` moved from `transforms.post` to `transforms.utility` + +### Removed +* Support of PyTorch 1.4 + +### Fixed +* Enhanced loss functions for stability and flexibility +* Sliding window inference memory and device issues +* Revised transforms: + * Normalize intensity datatype and normalizer types + * Padding modes for zoom + * Crop returns coordinates + * Select items transform + * Weighted patch sampling + * Option to keep aspect ratio for zoom +* Various CI/CD issues + ## [0.3.0] - 2020-10-02 ### Added * Overview document for [feature highlights in v0.3.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md) @@ -102,7 +173,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. [highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md -[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.3.0...HEAD +[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...HEAD +[0.4.0]: https://github.com/Project-MONAI/MONAI/compare/0.3.0...0.4.0 [0.3.0]: https://github.com/Project-MONAI/MONAI/compare/0.2.0...0.3.0 [0.2.0]: https://github.com/Project-MONAI/MONAI/compare/0.1.0...0.2.0 [0.1.0]: https://github.com/Project-MONAI/MONAI/commits/0.1.0 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d9b610ee64..81c5b32174 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,8 @@ _Pull request early_ We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. Change your pull request's title to begin with `[WIP]` until it is ready for formal review. +Please note that, as per PyTorch, MONAI uses American English spelling. This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc. + ### Preparing pull requests To ensure the code quality, MONAI relies on several linting tools ([flake8 and its plugins](https://gitlab.com/pycqa/flake8), [black](https://github.com/psf/black), [isort](https://github.com/timothycrosley/isort)), static type analysis tools ([mypy](https://github.com/python/mypy), [pytype](https://github.com/google/pytype)), as well as a set of unit/integration tests. diff --git a/README.md b/README.md index bda37a31b7..1741f2c518 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ For guidance on making a contribution to MONAI, see the [contributing guidelines ## Community Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9). -Ask and answer questions over on the [PyTorch Forums](https://discuss.pytorch.org/) or [StackOverflow](https://stackoverflow.com/tags/monai). Make sure to tag @monai. +Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions). ## Links - Website: https://monai.io/ diff --git a/docs/images/arch_modules_v0.3.png b/docs/images/arch_modules_v0.3.png deleted file mode 100644 index 768ff9de60..0000000000 Binary files a/docs/images/arch_modules_v0.3.png and /dev/null differ diff --git a/docs/images/arch_modules_v0.4.png b/docs/images/arch_modules_v0.4.png new file mode 100644 index 0000000000..ec5a7d9d21 Binary files /dev/null and b/docs/images/arch_modules_v0.4.png differ diff --git a/docs/images/cam.png b/docs/images/cam.png new file mode 100644 index 0000000000..3a8dcfed1d Binary files /dev/null and b/docs/images/cam.png differ diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 475a44de64..2962f725d8 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -34,12 +34,24 @@ ROC AUC metrics handler :members: -Confusion Matrix metrics handler +Confusion matrix metrics handler -------------------------------- .. autoclass:: ConfusionMatrix :members: +Hausdorff distance metrics handler +---------------------------------- +.. autoclass:: HausdorffDistance + :members: + + +Surface distance metrics handler +-------------------------------- +.. autoclass:: SurfaceDistance + :members: + + Metric logger ------------- .. autoclass:: MetricLogger diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 571f20db9f..d8fe5c2ff9 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -1,4 +1,4 @@ -# Modules in v0.3.0 +# Modules in v0.4.0 MONAI aims at supporting deep learning in medical image analysis at multiple granularities. This figure shows a typical example of the end-to-end workflow in medical deep learning area: @@ -12,12 +12,13 @@ The design principle of MONAI is to provide flexible and light APIs for users wi 4. Researchers contribute implementations based on the state-of-the-art for the latest research challenges, including COVID-19 image analysis, Model Parallel, etc. The overall architecture and modules are shown in the following figure: -![image](../images/arch_modules_v0.3.png) +![image](../images/arch_modules_v0.4.png) The rest of this page provides more details for each module. * [Data I/O, processing and augmentation](#medical-image-data-i-o-processing-and-augmentation) * [Datasets](#datasets) * [Loss functions](#losses) +* [Optimizers](#optimizers) * [Network architectures](#network-architectures) * [Evaluation](#evaluation) * [Visualization](#visualization) @@ -119,7 +120,12 @@ The design of MONAI transforms emphasis code readability and usability. It works For more details, please check out the tutorial: [integrate 3rd party transforms into MONAI program](https://github.com/Project-MONAI/tutorials/blob/master/modules/integrate_3rd_party_transforms.ipynb). ### 10. IO factory for medical image formats -Many popular image formats exist in the medical domain, and they are quite different with rich meta data information. To easily handle different medical image formats in the same pipeline, MONAI provides `LoadImage` transform, which uses `ITKReader` as the default image reader and also supports to register other readers, like `NibabelReader`, `NumpyReader`, and `PILReader`. The `ImageReader` API is quite straight-forward, users can easily extend for their own customized image readers. +Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in below priority order: +- User-specified reader at runtime when call this loader. +- Registered readers from the latest to the first in list. +- Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (others -> ITKReader). + +The `ImageReader` API is quite straight-forward, users can easily extend for their own customized image readers. With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc. @@ -127,17 +133,17 @@ With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, ### 1. Cache IO and transforms data to accelerate training Users often need to train the model with many (potentially thousands of) epochs over the data to achieve the desired model quality. A native PyTorch implementation may repeatedly load data and run the same preprocessing steps for every epoch during training, which can be time-consuming and unnecessary, especially when the medical image volumes are large. -MONAI provides a multi-threads `CacheDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). +MONAI provides a multi-threads `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). ![image](../images/cache_dataset.png) ### 2. Cache intermediate outcomes into persistent storage -The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). +The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage or LMDB for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). ![image](../images/datasets_speed.png) ### 3. SmartCache mechanism for big datasets During training with very big volume dataset, an efficient approach is to only train with a subset of the dataset in an epoch and dynamically replace part of the subset in every epoch. It's the `SmartCache` mechanism in [NVIDIA Clara-train SDK](https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache). -MONAI provides a PyTorch version `SmartCache` as `SmartCacheDataset`. In each epoch, only the items in the cache are used for training, at the same time, another thread is preparing replacement items by applying the transform sequence to items not in cache. Once one epoch is completed, `SmartCache` replaces the same number of items with replacement items. +MONAI provides a PyTorch version `SmartCache` as `SmartCacheDataset`. In each epoch, only the items in the cache are used for training, at the same time, another thread is preparing replacement items by applying the transform sequence to items not in the cache. Once one epoch is completed, `SmartCache` replaces the same number of items with replacement items. For example, if we have 5 images: `[image1, image2, image3, image4, image5]`, and `cache_num=4`, `replace_rate=0.25`. So the actual training images cached and replaced for every epoch are as below: ``` @@ -165,7 +171,18 @@ class DatasetB(Dataset): dataset = ZipDataset([DatasetA(), DatasetB()], transform) ``` -### 5. Predefined Datasets for public medical data +### 5. PatchDataset +`monai.data.PatchDataset` provides a flexible API to combine both image- and patch-level preprocessing: +```python +image_dataset = Dataset(input_images, transforms=image_transforms) +patch_dataset = PatchDataset( + dataset=image_dataset, patch_func=sampler, + samples_per_image=n_samples, transform=patch_transforms) +``` +It supports user-specified `image_transforms` and `patch_transforms` with customisable patch sampling strategies, +which decouples the two-level computations in a multiprocess context. + +### 6. Predefined Datasets for public medical data To quickly get started with popular training data in the medical domain, MONAI provides several data-specific Datasets(like: `MedNISTDataset`, `DecathlonDataset`, etc.), which include downloading from our AWS storage, extracting data files and support generation of training/evaluation items with transforms. And they are flexible that users can easily modify the JSON config file to change the default behaviors. MONAI always welcome new contributions of public datasets, please refer to existing Datasets and leverage the download and extracting APIs, etc. [Public datasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) indicates how to quickly set up training workflows with `MedNISTDataset` and `DecathlonDataset` and how to create a new `Dataset` for public data. @@ -173,9 +190,15 @@ MONAI always welcome new contributions of public datasets, please refer to exist The common workflow of predefined datasets: ![image](../images/dataset_progress.png) +### 7. Partition dataset for cross validation +The `partition_dataset` utility in MONAI can perform several kinds of mechanism to partition dataset for training and validation or cross-validation. It supports shuffling based on a specified random seed, and will return a set of datasets, each dataset contains one partition. And it can split the dataset based on specified ratios or evenly split into `num_partitions`. For given class labels, it can also make sure the same ratio of classes in every partition. + ## Losses There are domain-specific loss functions in the medical imaging research which are not typically used in the generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss` and `FocalLoss`, etc. +## Optimizers +MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge obviously faster than traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb). + ## Network architectures Some deep neural network architectures have shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability. @@ -192,7 +215,7 @@ name, dimension = Conv.CONVTRANS, 3 conv_type = Conv[name, dimension] add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False)) ``` -And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, etc. +And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, etc. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`. ## Evaluation To run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included: @@ -210,13 +233,19 @@ A typical process is: The [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb) leverages `SlidingWindow` inference for validation. ### 2. Metrics for medical tasks -Various useful evaluation metrics have been used to measure the quality of medical image specific models. MONAI already implemented many medical domain-specific metrics, such as: `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff Distance`, `Surface Distance`, etc. +Various useful evaluation metrics have been used to measure the quality of medical image specific models. MONAI already implemented many medical domain-specific metrics, such as: `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff Distance`, `Surface Distance`, `Occlusion Sensitivity`. -For example, `Mean Dice` score can be used for segmentation tasks and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options. +For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options. ## Visualization Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: + +![image](../images/cam.png) + +The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/master/modules/interpretability). + ## Result writing Currently MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image. @@ -280,7 +309,8 @@ We also tried to combine AMP with `CacheDataset` and `Novograd` optimizer to ach More details is available at [Fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb). ### 2. Distributed data parallel -Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on 1 node or several nodes to train or evaluate models. MONAI provides many demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, etc. And also provides a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation, it contains distributed caching, training, and validation. We tried to train this example on NVIDIA NGC server, got some performance benchmarks for reference(PyTorch 1.6, CUDA 11, Tesla V100 GPUs): +Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. +The demo contains distributed caching, training, and validation. We tried to train this example on NVIDIA NGC server, got some performance benchmarks for reference(PyTorch 1.6, CUDA 11, Tesla V100 GPUs): ![image](../images/distributed_training.png) ### 3. C++/CUDA optimized modules diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d3f5a347c7..0bcfbd4240 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -17,21 +17,27 @@ Metrics -------------------------- .. autofunction:: compute_roc_auc -`Confusion Matrix` +`Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix .. autoclass:: ConfusionMatrixMetric :members: -`Hausdorff Distance` +`Hausdorff distance` -------------------- .. autofunction:: compute_hausdorff_distance -`Average Surface Distance` +.. autoclass:: HausdorffDistanceMetric + :members: + +`Average surface distance` -------------------------- .. autofunction:: compute_average_surface_distance +.. autoclass:: SurfaceDistanceMetric + :members: + `Occlusion sensitivity` ----------------------- .. autofunction:: compute_occlusion_sensitivity \ No newline at end of file diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 779a42f9cb..ed17d815b4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -164,6 +164,11 @@ Layers .. currentmodule:: monai.networks.layers +`ChannelPad` +~~~~~~~~~~~~ +.. autoclass:: ChannelPad + :members: + `SkipConnection` ~~~~~~~~~~~~~~~~ .. autoclass:: SkipConnection @@ -178,6 +183,16 @@ Layers ~~~~~~~~~~~~~~~~ .. autoclass:: GaussianFilter :members: + +`BilateralFilter` +~~~~~~~~~~~~~~~~~ +.. autoclass:: BilateralFilter + :members: + +`HilbertTransform` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: HilbertTransform + :members: `Affine Transform` ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f02fb141e4..f7e075f376 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -105,6 +105,12 @@ Crop and Pad :members: :special-members: __call__ +`BoundingRect` +"""""""""""""" +.. autoclass:: BoundingRect + :members: + :special-members: __call__ + Intensity ^^^^^^^^^ @@ -210,6 +216,12 @@ Intensity :members: :special-members: __call__ +`DetectEnvelope` +""""""""""""""""""""" +.. autoclass:: DetectEnvelope + :members: + :special-members: __call__ + IO ^^ @@ -564,6 +576,12 @@ Crop and Pad (Dict) :members: :special-members: __call__ +`BoundingRectd` +""""""""""""""" +.. autoclass:: BoundingRectd + :members: + :special-members: __call__ + Instensity (Dict) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 6a03529c30..e0b993da60 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -26,3 +26,8 @@ Misc ---- .. automodule:: monai.utils.misc :members: + +Profiling +--------- +.. automodule:: monai.utils.profiling + :members: diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 99643fd4db..6272b50b4c 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -26,6 +26,8 @@ from monai.transforms import LoadImaged, Randomizable from monai.utils import ensure_tuple +__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"] + class MedNISTDataset(Randomizable, CacheDataset): """ @@ -121,7 +123,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: image_class.extend([i] * num_each[i]) num_total = len(image_class) - data = list() + data = [] for i in range(num_total): self.randomize() @@ -302,18 +304,17 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: if self.section == "test": return datalist - else: - length = len(datalist) - indices = np.arange(length) - self.randomize(indices) + length = len(datalist) + indices = np.arange(length) + self.randomize(indices) - val_length = int(length * self.val_frac) - if self.section == "training": - self.indices = indices[val_length:] - else: - self.indices = indices[:val_length] + val_length = int(length * self.val_frac) + if self.section == "training": + self.indices = indices[val_length:] + else: + self.indices = indices[:val_length] - return [datalist[i] for i in self.indices] + return [datalist[i] for i in self.indices] class CrossValidation: diff --git a/monai/apps/deepgrow/__init__.py b/monai/apps/deepgrow/__init__.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/monai/apps/deepgrow/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py new file mode 100644 index 0000000000..dc4c1a059d --- /dev/null +++ b/monai/apps/deepgrow/dataset.py @@ -0,0 +1,267 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import sys +from typing import Callable, Dict, List, Sequence, Union + +import numpy as np + +from monai.apps.datasets import DecathlonDataset +from monai.transforms import AsChannelFirstd, Compose, GridSampleMode, LoadNiftid, Orientationd, Spacingd + + +# TODO:: Test basic functionality +# TODO:: Unit Test +class DeepgrowDataset(DecathlonDataset): + def __init__( + self, + dimension: int, + pixdim: Sequence[float], + root_dir: str, + task: str, + section: str, + transform: Union[Sequence[Callable], Callable] = (), + download: bool = False, + seed: int = 0, + val_frac: float = 0.2, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int = 0, + limit: int = 0, + ) -> None: + self.dimension = dimension + self.pixdim = pixdim + self.limit = limit + + super().__init__( + root_dir=root_dir, + task=task, + section=section, + transform=transform, + download=download, + seed=seed, + val_frac=val_frac, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + ) + + def _generate_data_list(self, dataset_dir: str) -> List[Dict]: + dataset = super()._generate_data_list(dataset_dir) + + tmp_dataset_dir = dataset_dir + "_{}.deep".format(self.section) + new_datalist = create_dataset( + datalist=dataset, + keys=["image", "label"], + output_dir=tmp_dataset_dir, + dimension=self.dimension, + pixdim=self.pixdim, + limit=self.limit, + relative_path=False, + ) + + dataset_json = os.path.join(tmp_dataset_dir, "dataset.json") + with open(dataset_json, "w") as fp: + json.dump({self.section: new_datalist}, fp, indent=2) + return new_datalist + + +def _get_transforms(keys, pixdim): + mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR] + transforms = [ + LoadNiftid(keys=keys), + AsChannelFirstd(keys=keys), + Spacingd(keys=keys, pixdim=pixdim, mode=mode), + Orientationd(keys=keys, axcodes="RAS"), + ] + + return Compose(transforms) + + +def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape)) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + for sid in range(vol_image.shape[0]): + image = vol_image[sid, ...] + label = vol_label[sid, ...] if vol_label is not None else None + + if vol_label is not None and np.sum(label) == 0: + continue + + image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + continue + + # For all Labels + unique_labels = np.unique(label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + curr_label = (label == idx).astype(np.float32) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + print( + "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list + + +def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path): + vol_image = data[keys[0]] + vol_label = data.get(keys[1]) + data_list = [] + + if len(vol_image.shape) == 4: + logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape)) + vol_image = vol_image[0] + vol_image = np.moveaxis(vol_image, -1, 0) + + image_count = 0 + label_count = 0 + unique_labels_count = 0 + + image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx) + image_file = os.path.join(dataset_dir, "images", image_file_prefix) + image_file += ".npy" + + os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) + np.save(image_file, vol_image) + image_count += 1 + + # Test Data + if vol_label is None: + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + } + ) + else: + # For all Labels + unique_labels = np.unique(vol_label.flatten()) + unique_labels = unique_labels[unique_labels != 0] + unique_labels_count = max(unique_labels_count, len(unique_labels)) + + for idx in unique_labels: + label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx)) + label_file = os.path.join(dataset_dir, "labels", label_file_prefix) + label_file += ".npy" + + curr_label = (vol_label == idx).astype(np.float32) + os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) + np.save(label_file, curr_label) + + label_count += 1 + data_list.append( + { + "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file, + "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file, + "region": int(idx), + } + ) + + print( + "{} => Image: {} => {}; Label: {} => {}; Unique Labels: {}".format( + vol_idx, + vol_image.shape, + image_count, + vol_label.shape if vol_label is not None else None, + label_count, + unique_labels_count, + ) + ) + return data_list + + +def create_dataset( + datalist, output_dir, dimension, pixdim, keys=("image", "label"), base_dir=None, limit=0, relative_path=False +) -> List[Dict]: + if not isinstance(keys, list) and not isinstance(keys, tuple): + keys = [keys] + + transforms = _get_transforms(keys, pixdim) + new_datalist = [] + for idx in range(len(datalist)): + if limit and idx >= limit: + break + + image = datalist[idx][keys[0]] + label = datalist[idx].get(keys[1]) if len(keys) > 1 else None + if base_dir: + image = os.path.join(base_dir, image) + label = os.path.join(base_dir, label) if label else None + + print("{} => {}".format(image, label if label else None)) + if dimension == 2: + data = _save_data_2d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + else: + data = _save_data_3d( + vol_idx=idx, + data=transforms({"image": image, "label": label}), + keys=("image", "label"), + dataset_dir=output_dir, + relative_path=relative_path, + ) + new_datalist.extend(data) + return new_datalist diff --git a/monai/apps/deepgrow/handler.py b/monai/apps/deepgrow/handler.py new file mode 100644 index 0000000000..dbdbbf4289 --- /dev/null +++ b/monai/apps/deepgrow/handler.py @@ -0,0 +1,287 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import statistics + +import numpy as np +import torch +import torch.distributed +from torch.utils.tensorboard import SummaryWriter + +from monai.engines.workflow import Engine, Events +from monai.metrics import compute_meandice +from monai.transforms import rescale_array +from monai.utils import optional_import +from monai.visualize import plot_2d_or_3d_image + +nib, _ = optional_import("nibabel") +torchvision, _ = optional_import("torchvision") +make_grid, _ = optional_import("torchvision.utils", name="make_grid") + +# TODO:: Unit Test + + +class MeanDice: + def __init__(self): + self.data = [] + + def reset(self): + self.data = [] + + def update(self, y_pred, y, batched=True): + if not batched: + y_pred = y_pred[None] + y = y[None] + score = compute_meandice(y_pred=y_pred, y=y, include_background=False).mean() + self.data.append(score.item()) + + def mean(self): + return statistics.mean(self.data) + + def stdev(self): + return statistics.stdev(self.data) if len(self.data) > 1 else 0 + + +class DeepgrowStatsHandler(object): + def __init__( + self, + summary_writer=None, + interval=1, + log_dir="./runs", + tag_name="val_dice", + compute_metric=True, + images=True, + image_interval=1, + max_channels=1, + max_frames=64, + add_scalar=True, + add_stdev=False, + merge_scalar=False, + fold_size=0, + ): + self.writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer + self.interval = interval + self.tag_name = tag_name + self.compute_metric = compute_metric + self.images = images + self.image_interval = image_interval + self.max_channels = max_channels + self.max_frames = max_frames + self.add_scalar = add_scalar + self.add_stdev = add_stdev + self.merge_scalar = merge_scalar + self.fold_size = fold_size + + if torch.distributed.is_initialized(): + self.tag_name = "{}-r{}".format(self.tag_name, torch.distributed.get_rank()) + + self.plot_data = {} + self.metric_data = {} + + def attach(self, engine: Engine) -> None: + engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self, "iteration") + engine.add_event_handler(Events.EPOCH_COMPLETED(every=1), self, "epoch") + + def write_images(self, epoch): + if not self.plot_data or not len(self.plot_data): + return + + all_imgs = [] + titles = [] + for region in sorted(self.plot_data.keys()): + all_imgs.extend(self.plot_data[region]) + metric = self.metric_data.get(region) + dice = "{:.4f}".format(metric.mean()) if self.compute_metric and metric else "" + stdev = "{:.4f}".format(metric.stdev()) if self.compute_metric and metric else "" + titles.extend( + [ + "x({})".format(region), + "y({})".format(region), + "dice: {} +/- {}".format(dice, stdev) if self.compute_metric else "yh({})".format(region), + ] + ) + + if len(all_imgs[0].shape) == 3: + img_tensor = make_grid( + tensor=torch.from_numpy(np.array(all_imgs)), + nrow=3, + normalize=True, + pad_value=2, + ) + self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=img_tensor, global_step=epoch) + + if len(all_imgs[0].shape) == 4: + for region in sorted(self.plot_data.keys()): + tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"] + for i in range(3): + img = self.plot_data[region][i] + plot_2d_or_3d_image( + img[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tags[i] + ) + + logging.info( + "Saved {} Regions {} into Tensorboard at epoch: {}".format( + len(self.plot_data), sorted([*self.plot_data]), epoch + ) + ) + self.writer.flush() + + def write_region_metrics(self, epoch): + metric_sum = 0 + means = {} + stdevs = {} + for region in self.metric_data: + metric = self.metric_data[region].mean() + stdev = self.metric_data[region].stdev() + if self.merge_scalar: + means["{:0>2d}".format(region)] = metric + stdevs["{:0>2d}".format(region)] = stdev + else: + if self.add_stdev: + self.writer.add_scalar("{}_{:0>2d}_mean".format(self.tag_name, region), metric, epoch) + self.writer.add_scalar("{}_{:0>2d}_mean+".format(self.tag_name, region), metric + stdev, epoch) + self.writer.add_scalar("{}_{:0>2d}_mean-".format(self.tag_name, region), metric - stdev, epoch) + else: + self.writer.add_scalar("{}_{:0>2d}".format(self.tag_name, region), metric, epoch) + metric_sum += metric + if self.merge_scalar: + self.writer.add_scalars("{}_region".format(self.tag_name), means, epoch) + + if len(self.metric_data) > 1: + metric_avg = metric_sum / len(self.metric_data) + self.writer.add_scalar("{}_regions_avg".format(self.tag_name), metric_avg, epoch) + self.writer.flush() + + def __call__(self, engine: Engine, action) -> None: + total_steps = engine.state.iteration + if total_steps < engine.state.epoch_length: + total_steps = engine.state.epoch_length * (engine.state.epoch - 1) + total_steps + + if action == "epoch" and not self.fold_size: + epoch = engine.state.epoch + elif self.fold_size and total_steps % self.fold_size == 0: + epoch = int(total_steps / self.fold_size) + else: + epoch = None + + if epoch: + if self.images and epoch % self.image_interval == 0: + self.write_images(epoch) + if self.add_scalar: + self.write_region_metrics(epoch) + + if action == "epoch" or epoch: + self.plot_data = {} + self.metric_data = {} + return + + device = engine.state.device + batch_data = engine.state.batch + output_data = engine.state.output + + for bidx in range(len(batch_data.get("region", []))): + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + if self.images and self.plot_data.get(region) is None: + self.plot_data[region] = [ + rescale_array(batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis], 0, 1), + rescale_array(batch_data["label"][bidx].detach().cpu().numpy(), 0, 1), + rescale_array(output_data["pred"][bidx].detach().cpu().numpy(), 0, 1), + ] + + if self.compute_metric: + if self.metric_data.get(region) is None: + self.metric_data[region] = MeanDice() + self.metric_data[region].update( + y_pred=output_data["pred"][bidx].to(device), y=batch_data["label"][bidx].to(device), batched=False + ) + + +class SegmentationSaver: + def __init__( + self, + output_dir: str = "./runs", + save_np=False, + images=True, + ): + self.output_dir = output_dir + self.save_np = save_np + self.images = images + os.makedirs(self.output_dir, exist_ok=True) + + def attach(self, engine: Engine) -> None: + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def __call__(self, engine: Engine): + batch_data = engine.state.batch + output_data = engine.state.output + device = engine.state.device + tag = "" + if torch.distributed.is_initialized(): + tag = "r{}-".format(torch.distributed.get_rank()) + + for bidx in range(len(batch_data.get("image"))): + step = engine.state.iteration + region = batch_data.get("region")[bidx] + region = region.item() if torch.is_tensor(region) else region + + image = batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis] + label = batch_data["label"][bidx].detach().cpu().numpy() + pred = output_data["pred"][bidx].detach().cpu().numpy() + dice = compute_meandice( + y_pred=output_data["pred"][bidx][None].to(device), + y=batch_data["label"][bidx][None].to(device), + include_background=False, + ).mean() + + if self.save_np: + np.savez( + os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(tag, region, step, bidx, dice), + ), + image, + label, + pred, + ) + + if self.images and len(image.shape) == 3: + img = make_grid(torch.from_numpy(rescale_array(image, 0, 1)[0])) + lab = make_grid(torch.from_numpy(rescale_array(label, 0, 1)[0])) + + pos = rescale_array(output_data["image"][bidx][1].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + neg = rescale_array(output_data["image"][bidx][2].detach().cpu().numpy()[np.newaxis], 0, 1)[0] + pre = make_grid(torch.from_numpy(np.array([rescale_array(pred, 0, 1)[0], pos, neg]))) + + torchvision.utils.save_image( + tensor=[img, lab, pre], + nrow=3, + pad_value=2, + fp=os.path.join( + self.output_dir, + "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".format(tag, region, step, bidx, dice), + ), + ) + + if self.images and len(image.shape) == 4: + samples = {"image": image[0], "label": label[0], "pred": pred[0]} + for sample in samples: + img = nib.Nifti1Image(samples[sample], np.eye(4)) + nib.save( + img, + os.path.join( + self.output_dir, "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(tag, sample, step, bidx, dice) + ), + ) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py new file mode 100644 index 0000000000..d29399e2da --- /dev/null +++ b/monai/apps/deepgrow/interaction.py @@ -0,0 +1,66 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import torch +from ignite.engine import Engine, Events +from torch.cuda.amp import autocast + +from monai.engines.utils import CommonKeys + +# TODO:: Unit Test + + +class Interaction: + """ + Deepgrow Training/Evaluation iteration method with interactions (simulation of clicks) support for image and label. + + Args: + transforms: execute additional transformation during every iteration (before train). + Typically, several Tensor based transforms composed by `Compose`. + max_interactions: maximum number of interactions per iteration + train: training or evaluation + key_probability: field name to fill probability for every interaction + """ + + def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None: + self.transforms = transforms + self.max_interactions = max_interactions + self.train = train + self.key_probability = key_probability + + def attach(self, engine: Engine) -> None: + engine.add_event_handler(Events.ITERATION_STARTED, self) + + def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + for j in range(self.max_interactions): + inputs, _ = engine.prepare_batch(batchdata) + inputs = inputs.to(engine.state.device) + + engine.network.eval() + with torch.no_grad(): + if engine.amp: + with autocast(): + predictions = engine.inferer(inputs, engine.network) + else: + predictions = engine.inferer(inputs, engine.network) + + batchdata.update({CommonKeys.PRED: predictions}) + batchdata[self.key_probability] = torch.as_tensor( + ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) + ) + batchdata = self.transforms(batchdata) + + return engine._iteration(engine, batchdata) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py new file mode 100644 index 0000000000..f5dcfdd253 --- /dev/null +++ b/monai/apps/deepgrow/transforms.py @@ -0,0 +1,550 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of "vanilla" transforms for spatial operations +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" +import json +from typing import Optional, Union + +import numpy as np + +from monai.config import KeysCollection +from monai.transforms import InterpolateMode, InterpolateModeSequence, Resize, SpatialCrop +from monai.transforms.compose import MapTransform, Randomizable, Transform +from monai.transforms.utils import generate_spatial_bounding_box +from monai.utils import Sequence, ensure_tuple_rep, min_version, optional_import + +measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +gaussian_filter, _ = optional_import("scipy.ndimage", name="gaussian_filter") + + +class AddInitialSeedPointd(Randomizable, Transform): + def __init__(self, label="label", guidance="guidance", dimensions=2, connected_regions=6): + self.label = label + self.guidance = guidance + self.dimensions = dimensions + self.connected_regions = connected_regions + + def randomize(self, data=None): + pass + + def _apply(self, label): + label = (label > 0.5).astype(np.float32) + + blobs_labels = measure.label(label.astype(int), background=0) if self.dimensions == 2 else label + assert np.max(blobs_labels) > 0, "Not a valid Label" + + default_guidance = [-1] * (self.dimensions + 1) + pos_guidance = [] + for ridx in range(1, 2 if self.dimensions == 3 else self.connected_regions): + if self.dimensions == 2: + label = (blobs_labels == ridx).astype(np.float32) + if np.sum(label) == 0: + pos_guidance.append(default_guidance) + continue + + distance = distance_transform_cdt(label).flatten() + probability = np.exp(distance) - 1.0 + + idx = np.where(label.flatten() > 0)[0] + seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0] + g[0] = dst[0] + pos_guidance.append(g) + + return np.asarray([pos_guidance, [default_guidance] * len(pos_guidance)]) + + def __call__(self, data): + data[self.guidance] = self._apply(data[self.label]) + return data + + +class AddGuidanceSignald(Transform): + def __init__(self, image="image", guidance="guidance", sigma=2, dimensions=2, number_intensity_ch=1, batched=False): + self.image = image + self.guidance = guidance + self.sigma = sigma + self.dimensions = dimensions + self.number_intensity_ch = number_intensity_ch + self.batched = batched + + def _get_signal(self, image, guidance): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + if self.dimensions == 3: + signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) + else: + signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32) + + for i in range(len(guidance)): + for point in guidance[i]: + if np.any(np.asarray(point) < 0): + continue + + if self.dimensions == 3: + signal[i, int(point[-3]), int(point[-2]), int(point[-1])] = 1.0 + else: + signal[i, int(point[-2]), int(point[-1])] = 1.0 + + if np.max(signal[i]) > 0: + signal[i] = gaussian_filter(signal[i], sigma=self.sigma) + signal[i] = (signal[i] - np.min(signal[i])) / (np.max(signal[i]) - np.min(signal[i])) + return signal + + def _apply(self, image, guidance): + if not self.batched: + signal = self._get_signal(image, guidance) + return np.concatenate([image, signal], axis=0) + + images = [] + for i, g in zip(image, guidance): + i = i[0 : 0 + self.number_intensity_ch, ...] + signal = self._get_signal(i, g) + images.append(np.concatenate([i, signal], axis=0)) + return images + + def __call__(self, data): + image = data[self.image] + guidance = data[self.guidance] + + data[self.image] = self._apply(image, guidance) + return data + + +class FindDiscrepancyRegionsd(Transform): + def __init__(self, label="label", pred="pred", discrepancy="discrepancy", batched=True): + self.label = label + self.pred = pred + self.discrepancy = discrepancy + self.batched = batched + + @staticmethod + def disparity(label, pred): + label = (label > 0.5).astype(np.float32) + pred = (pred > 0.5).astype(np.float32) + disparity = label - pred + + pos_disparity = (disparity > 0).astype(np.float32) + neg_disparity = (disparity < 0).astype(np.float32) + return [pos_disparity, neg_disparity] + + def _apply(self, label, pred): + if not self.batched: + return self.disparity(label, pred) + + disparity = [] + for la, pr in zip(label, pred): + disparity.append(self.disparity(la, pr)) + return disparity + + def __call__(self, data): + label = data[self.label] + pred = data[self.pred] + + data[self.discrepancy] = self._apply(label, pred) + return data + + +class AddRandomGuidanced(Randomizable, Transform): + def __init__( + self, guidance="guidance", discrepancy="discrepancy", probability="probability", dimensions=2, batched=True + ): + self.guidance = guidance + self.discrepancy = discrepancy + self.probability = probability + self.dimensions = dimensions + self.batched = batched + + def randomize(self, data=None): + pass + + @staticmethod + def find_guidance(discrepancy): + distance = distance_transform_cdt(discrepancy).flatten() + probability = np.exp(distance) - 1.0 + idx = np.where(discrepancy.flatten() > 0)[0] + + if np.sum(discrepancy > 0) > 0: + seed = np.random.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0] + g[0] = dst[0] + return g + return None + + @staticmethod + def add_guidance(discrepancy, probability): + will_interact = np.random.choice([True, False], p=[probability, 1.0 - probability]) + if not will_interact: + return None, None + + pos_discr = discrepancy[0] + neg_discr = discrepancy[1] + + can_be_positive = np.sum(pos_discr) > 0 + can_be_negative = np.sum(neg_discr) > 0 + correct_pos = np.sum(pos_discr) >= np.sum(neg_discr) + + if correct_pos and can_be_positive: + return AddRandomGuidanced.find_guidance(pos_discr), None + + if not correct_pos and can_be_negative: + return None, AddRandomGuidanced.find_guidance(neg_discr) + return None, None + + def _apply(self, guidance, discrepancy, probability): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + default_guidance = [-1] * (self.dimensions + 1) + + if not self.batched: + pos, neg = self.add_guidance(discrepancy, probability) + if pos: + guidance[0].append(pos) + guidance[1].append(default_guidance) + if neg: + guidance[0].append(default_guidance) + guidance[1].append(neg) + else: + for g, d, p in zip(guidance, discrepancy, probability): + pos, neg = self.add_guidance(d, p) + if pos: + g[0].append(pos) + g[1].append(default_guidance) + if neg: + g[0].append(default_guidance) + g[1].append(neg) + return np.asarray(guidance) + + def __call__(self, data): + guidance = data[self.guidance] + discrepancy = data[self.discrepancy] + probability = data[self.probability] + + data[self.guidance] = self._apply(guidance, discrepancy, probability) + return data + + +class SpatialCropForegroundd(MapTransform): + def __init__( + self, + keys, + source_key: str, + spatial_size, + select_fn=lambda x: x > 0, + channel_indices=None, + margin: int = 0, + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + + self.source_key = source_key + self.spatial_size = spatial_size + self.select_fn = select_fn + self.channel_indices = channel_indices + self.margin = margin + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + box_start, box_end = generate_spatial_bounding_box( + data[self.source_key], self.select_fn, self.channel_indices, self.margin + ) + + center = np.mean([box_start, box_end], axis=0).astype(int).tolist() + current_size = np.subtract(box_end, box_start).astype(int).tolist() + + if np.all(np.less(current_size, self.spatial_size)): + cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) + box_start = cropper.roi_start + box_end = cropper.roi_end + else: + cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) + + for key in self.keys: + meta_key = f"{key}_{self.meta_key_postfix}" + data[meta_key][self.start_coord_key] = box_start + data[meta_key][self.end_coord_key] = box_end + data[meta_key][self.original_shape_key] = data[key].shape + + image = cropper(data[key]) + data[meta_key][self.cropped_shape_key] = image.shape + data[key] = image + return data + + +# Transforms to support Inference +class SpatialCropGuidanced(MapTransform): + def __init__( + self, + keys, + guidance: str, + spatial_size, + spatial_size_key: str = "spatial_size", + meta_key_postfix="meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + + self.guidance = guidance + self.spatial_size = spatial_size + self.spatial_size_key = spatial_size_key + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + guidance = data[self.guidance] + center = np.mean(guidance[0] + guidance[1], axis=0).astype(int).tolist() + spatial_size = data.get(self.spatial_size_key, self.spatial_size) + + cropper = SpatialCrop(roi_center=center, roi_size=spatial_size) + box_start, box_end = cropper.roi_start, cropper.roi_end + + for key in self.keys: + meta_key = f"{key}_{self.meta_key_postfix}" + data[meta_key][self.start_coord_key] = box_start + data[meta_key][self.end_coord_key] = box_end + data[meta_key][self.original_shape_key] = data[key].shape + + image = cropper(data[key]) + data[meta_key][self.cropped_shape_key] = image.shape + data[key] = image + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else [] + neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class ResizeGuidanced(Transform): + def __init__( + self, + guidance: str, + ref_image, + meta_key_postfix="meta_dict", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + self.guidance = guidance + self.ref_image = ref_image + self.meta_key_postfix = meta_key_postfix + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + guidance = data[self.guidance] + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + current_shape = data[self.ref_image].shape[1:] + cropped_shape = meta_dict[self.cropped_shape_key][1:] + factor = np.divide(current_shape, cropped_shape) + + pos_clicks, neg_clicks = guidance[0], guidance[1] + pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + + data[self.guidance] = [pos, neg] + return data + + +class RestoreCroppedLabeld(MapTransform): + def __init__( + self, + keys: KeysCollection, + ref_image: str, + slice_only=False, + channel_first=True, + mode: InterpolateModeSequence = InterpolateMode.NEAREST, + align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, + meta_key_postfix: str = "meta_dict", + start_coord_key: str = "foreground_start_coord", + end_coord_key: str = "foreground_end_coord", + original_shape_key: str = "foreground_original_shape", + cropped_shape_key: str = "foreground_cropped_shape", + ) -> None: + super().__init__(keys) + self.ref_image = ref_image + self.slice_only = slice_only + self.channel_first = channel_first + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.meta_key_postfix = meta_key_postfix + self.start_coord_key = start_coord_key + self.end_coord_key = end_coord_key + self.original_shape_key = original_shape_key + self.cropped_shape_key = cropped_shape_key + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + + for idx, key in enumerate(self.keys): + image = data[key] + + # Undo Resize + current_size = image.shape + cropped_size = meta_dict[self.cropped_shape_key] + if np.any(np.not_equal(current_size, cropped_size)): + resizer = Resize(spatial_size=cropped_size[1:], mode=self.mode[idx]) + image = resizer(image, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Crop + original_shape = meta_dict[self.original_shape_key] + result = np.zeros(original_shape, dtype=np.float32) + box_start = meta_dict[self.start_coord_key] + box_end = meta_dict[self.end_coord_key] + + sd = min(len(box_start), len(box_end), len(image.shape[1:])) # spatial dims + slices = [slice(None)] + [slice(s, e) for s, e in zip(box_start[:sd], box_end[:sd])] + slices = tuple(slices) + result[slices] = image + + # Undo Spacing + current_size = result.shape[1:] + spatial_shape = np.roll(meta_dict["spatial_shape"], 1).tolist() + spatial_size = spatial_shape[-len(current_size) :] + + if np.any(np.not_equal(current_size, spatial_size)): + resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx]) + result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx]) + + # Undo Slicing + slice_idx = meta_dict.get("slice_idx") + if slice_idx is None or self.slice_only: + final_result = result if len(result.shape) <= 3 else result[0] + else: + slice_idx = meta_dict["slice_idx"][0] + final_result = np.zeros(spatial_shape) + if self.channel_first: + final_result[slice_idx] = result + else: + final_result[..., slice_idx] = result + data[key] = final_result + + meta = data.get(f"{key}_{self.meta_key_postfix}") + if meta is None: + meta = dict() + data[f"{key}_{self.meta_key_postfix}"] = meta + meta["slice_idx"] = slice_idx + meta["affine"] = meta_dict["original_affine"] + return data + + +class AddGuidanceFromPointsd(Randomizable, Transform): + def __init__( + self, + ref_image, + guidance="guidance", + foreground="foreground", + background="background", + axis=0, + channel_first=True, + dimensions=2, + slice_key="slice", + meta_key_postfix: str = "meta_dict", + ): + self.ref_image = ref_image + self.guidance = guidance + self.foreground = foreground + self.background = background + self.axis = axis + self.channel_first = channel_first + self.dimensions = dimensions + self.slice_key = slice_key + self.meta_key_postfix = meta_key_postfix + + def randomize(self, data=None): + pass + + def _apply(self, pos_clicks, neg_clicks, factor, slice_num=None): + points = pos_clicks + points.extend(neg_clicks) + points = np.array(points) + + if self.dimensions == 2: + slices = np.unique(points[:, self.axis]).tolist() + slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num) + + pos = neg = [] + if len(pos_clicks): + pos_clicks = np.array(pos_clicks) + pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + if len(neg_clicks): + neg_clicks = np.array(neg_clicks) + neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist() + + guidance = [pos, neg, slice_idx, factor] + else: + pos = neg = [] + if len(pos_clicks): + pos = np.multiply(pos_clicks, factor).astype(int).tolist() + if len(neg_clicks): + neg = np.multiply(neg_clicks, factor).astype(int).tolist() + guidance = [pos, neg] + return guidance + + def __call__(self, data): + meta_dict = data[f"{self.ref_image}_{self.meta_key_postfix}"] + original_shape = meta_dict["spatial_shape"] + current_shape = list(data[self.ref_image].shape) + + clicks = [data[self.foreground], data[self.background]] + if self.channel_first: + original_shape = np.roll(original_shape, 1).tolist() + for i in range(len(clicks)): + clicks[i] = json.loads(clicks[i]) if isinstance(clicks[i], str) else clicks[i] + clicks[i] = np.array(clicks[i]).astype(int).tolist() + for j in range(len(clicks[i])): + clicks[i][j] = np.roll(clicks[i][j], 1).tolist() + + factor = np.array(current_shape) / original_shape + + data[self.guidance] = self._apply(clicks[0], clicks[1], factor, data.get(self.slice_key)) + return data + + +class Fetch2DSliced(MapTransform): + def __init__(self, keys, guidance="guidance", axis=0, meta_key_postfix: str = "meta_dict"): + super().__init__(keys) + self.guidance = guidance + self.axis = axis + self.meta_key_postfix = meta_key_postfix + + def _apply(self, image, guidance): + slice_idx = guidance[2] + idx = [] + for i in range(len(image.shape)): + idx.append(slice_idx) if i == self.axis else idx.append(slice(0, image.shape[i])) + + idx = tuple(idx) + return image[idx], idx + + def __call__(self, data): + guidance = data[self.guidance] + for key in self.keys: + img, idx = self._apply(data[key], guidance) + data[key] = img + data[f"{key}_{self.meta_key_postfix}"]["slice_idx"] = idx + return data diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 8461bf4a29..e48dfb63f2 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -31,6 +31,13 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") +__all__ = [ + "check_hash", + "download_url", + "extractall", + "download_and_extract", +] + def check_hash(filepath: str, val: Optional[str] = None, hash_type: str = "md5") -> bool: """ diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 1bb2bb7907..c70d495555 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -100,17 +100,9 @@ def set_visible_devices(*dev_inds): os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, dev_inds)) -def get_torch_version_tuple(): - """ - Returns: - tuple of ints represents the pytorch major/minor version. - """ - return tuple((int(x) for x in torch.__version__.split(".")[:2])) - - def _dict_append(in_dict, key, fn): try: - in_dict[key] = fn() + in_dict[key] = fn() if callable(fn) else fn except BaseException: in_dict[key] = "UNKNOWN for given OS" @@ -205,7 +197,7 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, "Current device", lambda: torch.cuda.current_device()) _dict_append(output, "Library compiled for CUDA architectures", lambda: torch.cuda.get_arch_list()) for gpu in range(num_gpus): - _dict_append(output, "Info for GPU", lambda: gpu) + _dict_append(output, "Info for GPU", gpu) gpu_info = torch.cuda.get_device_properties(gpu) _dict_append(output, "\tName", lambda: gpu_info.name) _dict_append(output, "\tIs integrated", lambda: bool(gpu_info.is_integrated)) diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index 9dd75a7e90..ecf08af107 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -17,7 +17,7 @@ that should be used consistently throughout the entire MONAI package. A type would be named as type_definitions.KeysCollection -which includes a meaningful name for the concent in the name itself. The +which includes a meaningful name for the consent in the name itself. The definitions in this file map context meaningful names to the underlying object properties that define the expected API. diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index 5aaa2e70c9..6740d1b5b4 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -12,11 +12,16 @@ limitations under the License. */ #include + +#include "filtering/filtering.h" #include "lltm/lltm.h" #include "resample/pushpull.h" #include "utils/resample_utils.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // filtering + m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); + // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); m.def("lltm_backward", &lltm_backward, "LLTM backward"); diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h new file mode 100644 index 0000000000..68f8a3093c --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateral.h @@ -0,0 +1,42 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include "utils/common_utils.h" + +torch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma); +torch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma); + +#ifdef WITH_CUDA +torch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, float color_sigma); +torch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma); +#endif + +torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) { + torch::Tensor (*filterFunction)(torch::Tensor, float, float); + +#ifdef WITH_CUDA + if (torch::cuda::is_available() && input.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(input); + filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda; + } else { + filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu; + } +#else + filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu; +#endif + + return filterFunction(input, spatial_sigma, color_sigma); +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp new file mode 100644 index 0000000000..cdce729f17 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -0,0 +1,167 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include + +#include "utils/tensor_description.h" + +struct Indexer { + public: + Indexer(int dimensions, int* sizes) { + m_dimensions = dimensions; + m_sizes = sizes; + m_index = new int[dimensions]{0}; + } + + bool operator++(int) { + for (int i = 0; i < m_dimensions; i++) { + m_index[i] += 1; + + if (m_index[i] < m_sizes[i]) { + return true; + } else { + m_index[i] = 0; + } + } + + return false; + } + + int& operator[](int dimensionIndex) { + return m_index[dimensionIndex]; + } + + private: + int m_dimensions; + int* m_sizes; + int* m_index; +}; + +template +void BilateralFilterCpu(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Raw tensor data pointers. + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + + // Pre-calculate common values + int windowSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size + int halfWindowSize = floor(0.5f * windowSize); + scalar_t spatialExpConstant = -1.0f / (2 * spatialSigma * spatialSigma); + scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma); + + // Kernel sizes. + int* kernelSizes = new int[desc.dimensions]; + + for (int i = 0; i < desc.dimensions; i++) { + kernelSizes[i] = windowSize; + } + + // Pre-calculate gaussian kernel in 1D. + scalar_t* gaussianKernel = new scalar_t[windowSize]; + + for (int i = 0; i < windowSize; i++) { + int distance = i - halfWindowSize; + gaussianKernel[i] = exp(distance * distance * spatialExpConstant); + } + + // Kernel aggregates used to calculate + // the output value. + scalar_t* valueSum = new scalar_t[desc.channelCount]; + scalar_t weightSum = 0; + + // Looping over the batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Looping over all dimensions for the home element + Indexer homeIndex = Indexer(desc.dimensions, desc.sizes); + do // while(homeIndex++) + { + // Calculating indexing offset for the home element + int homeOffset = batchOffset; + + for (int i = 0; i < desc.dimensions; i++) { + homeOffset += homeIndex[i] * desc.strides[i]; + } + + // Zero kernel aggregates. + for (int i = 0; i < desc.channelCount; i++) { + valueSum[i] = 0; + } + + weightSum = 0.0f; + + // Looping over all dimensions for the neighbour element + Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes); + do // while(kernelIndex++) + { + // Calculating buffer offset for the neighbour element + // Index is clamped to the border in each dimension. + int neighbourOffset = batchOffset; + + for (int i = 0; i < desc.dimensions; i++) { + int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize; + int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex)); + neighbourOffset += neighbourIndexClamped * desc.strides[i]; + } + + // Euclidean color distance. + scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + for (int i = 0; i < desc.dimensions; i++) { + spatialWeight *= gaussianKernel[kernelIndex[i]]; + } + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. + for (int i = 0; i < desc.channelCount; i++) { + valueSum[i] += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + } + + weightSum += totalWeight; + } while (kernelIndex++); + + for (int i = 0; i < desc.channelCount; i++) { + outputTensorData[homeOffset + i * desc.channelStride] = valueSum[i] / weightSum; + } + } while (homeIndex++); + } +} + +torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.type(), "BilateralFilterCpu", ([&] { + BilateralFilterCpu( + inputTensor, outputTensor, spatialSigma, colorSigma); + })); + + return outputTensor; +} \ No newline at end of file diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp new file mode 100644 index 0000000000..eb94749ea5 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -0,0 +1,89 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +#include "filtering/permutohedral/permutohedral.h" +#include "utils/tensor_description.h" + +template +void BilateralFilterPHLCpu( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + float spatialSigma, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + int featureChannels = desc.channelCount + desc.dimensions; + + // Preparing memory + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* data = new scalar_t[desc.channelStride * desc.channelCount]; + scalar_t* features = new scalar_t[desc.channelStride * featureChannels]; + + // Precalculating inverse sigmas + float invSpatialSigma = 1.0f / spatialSigma; + float invColorSigma = 1.0f / colorSigma; + + // Looping over batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Creating features (also permuting input data to be channel last. Permutohedral + // implementation should be changed to channel first to avoid this) + for (int i = 0; i < desc.channelStride; i++) { + // Color features (and permutation) + for (int c = 0; c < desc.channelCount; c++) { + features[i * featureChannels + c] = invColorSigma * inputTensorData[batchOffset + i + c * desc.channelStride]; + data[i * desc.channelCount + c] = inputTensorData[batchOffset + i + c * desc.channelStride]; + } + + // Spatial features + int offsetRemanider = i; + + for (int d = 0; d < desc.dimensions; d++) { + int coord = offsetRemanider / desc.strides[d]; + offsetRemanider -= coord * desc.strides[d]; + + features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord; + } + } + + // Filtering data with respect to the features. + scalar_t* output = + PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); + + // Writing output tensor. + for (int i = 0; i < desc.channelStride; i++) { + for (int c = 0; c < desc.channelCount; c++) { + outputTensorData[batchOffset + i + c * desc.channelStride] = output[i * desc.channelCount + c]; + } + } + } + + delete[] data; + delete[] features; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterPhlCpu", ([&] { + BilateralFilterPHLCpu(inputTensor, outputTensor, spatialSigma, colorSigma); + })); + + return outputTensor; +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu new file mode 100644 index 0000000000..872ff652cb --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -0,0 +1,245 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSize; +__constant__ float cKernel[256]; + +__constant__ float cColorExponentFactor; + +template +__global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + scalar_t weightSum = 0; + + for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) { + int neighbourOffset = max(0, min(homeOffset + (kernelOffset - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussian = cKernel[kernelOffset]; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussian; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSize; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - kernelHalfSize), cSizes[2] - 1)); + scalar_t gaussianZ = cKernel[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating exponent factors. + float spatialExponentFactor = -1.0f / (2 * spatialSigma * spatialSigma); + float colorExponentFactor = -1.0f / (2 * colorSigma * colorSigma); + + // Pre-calculating gaussian kernel. + int kernelSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size + int kernelHalfSize = floor(0.5f * kernelSize); + float* kernel = new float[kernelSize]; + + for (int i = 0; i < kernelSize; i++) { + int distance = i - kernelHalfSize; + kernel[i] = exp(distance * distance * spatialExponentFactor); + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * D); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * D); + cudaMemcpyToSymbol(cKernelSize, &kernelSize, sizeof(int)); + cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); + cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.type(), "BilateralFilterCudaKernel", ([&] { + // Dispatch kernel. (Partial template function specialisation not supported at present so using this switch + // instead) + switch (D) { + case (1): + BilateralFilterCudaKernel1D<<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + case (2): + BilateralFilterCudaKernel2D<<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + case (3): + BilateralFilterCudaKernel3D<<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); + break; + } + })); + + delete[] kernel; +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + +#define CASE(c, d) BilateralFilterCuda(inputTensor, outputTensor, spatialSigma, colorSigma); + SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2); + + return outputTensor; +} diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu new file mode 100644 index 0000000000..df4ed8771b --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -0,0 +1,130 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "filtering/permutohedral/permutohedral.h" +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cChannelStride; +__constant__ int cSpatialStrides[3]; +__constant__ float cInvSpatialSigma; +__constant__ float cInvColorSigma; + +template +__global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputData, scalar_t* outputFeatures) { + int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; + int batchIndex = blockIdx.y; + + int dataBatchOffset = batchIndex * cBatchStride; + int featureBatchOffset = batchIndex * (D + C) * cChannelStride; + +#pragma unroll + for (int i = 0; i < C; i++) { + outputData[dataBatchOffset + elementIndex * C + i] = + inputTensor[dataBatchOffset + elementIndex + i * cChannelStride]; + outputFeatures[featureBatchOffset + elementIndex * (C + D) + i] = + inputTensor[dataBatchOffset + elementIndex + i * cChannelStride] * cInvColorSigma; + } + + int remainder = elementIndex; + +#pragma unroll + for (int i = 0; i < D; i++) { + int coord = remainder / cSpatialStrides[i]; + remainder -= coord * cSpatialStrides[i]; + + outputFeatures[featureBatchOffset + elementIndex * (C + D) + C + i] = coord * cInvSpatialSigma; + } +} + +template +__global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) { + int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; + int batchIndex = blockIdx.y; + int batchOffset = batchIndex * cBatchStride; + +#pragma unroll + for (int i = 0; i < C; i++) { + outputTensor[batchOffset + elementIndex + i * cChannelStride] = data[batchOffset + elementIndex * C + i]; + } +} + +template +void BilateralFilterPHLCuda( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + float spatialSigma, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + int featureChannelCount = desc.channelCount + desc.dimensions; + + // Pre calculating inverse sigmas. + float invSpatialSigma = 1.0f / spatialSigma; + float invColorSigma = 1.0f / colorSigma; + + // Preparing global memory + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + + scalar_t* data; + scalar_t* features; + cudaMalloc(&data, desc.batchCount * desc.channelStride * desc.channelCount * sizeof(scalar_t)); + cudaMalloc(&features, desc.batchCount * desc.channelStride * featureChannelCount * sizeof(scalar_t)); + + // Prparing constant memory + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cChannelStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSpatialStrides, desc.strides, sizeof(int) * desc.dimensions); + cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float)); + cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float)); + + // Creating features + FeatureCreation + <<>>(inputTensorData, data, features); + + // Filtering data with respect to the features for each sample in batch + for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) { + scalar_t* offsetData = data + batchIndex * desc.batchStride; + scalar_t* offsetFeatures = features + batchIndex * featureChannelCount * desc.channelStride; + + PermutohedralCuda(offsetData, offsetFeatures, desc.channelStride, true); + } + + // Writing output + WriteOutput<<>>(data, outputTensorData); + + cudaFree(data); + cudaFree(features); +} + +// Function to choose template implementation based on dynamic, channels and dimensions +torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + +#define CASE(c, d) \ + AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterCudaPHL", ([&] { \ + BilateralFilterPHLCuda( \ + inputTensor, outputTensor, spatialSigma, colorSigma); \ + })); + + SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2); + + return outputTensor; +} diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h new file mode 100644 index 0000000000..18cf2ae6f4 --- /dev/null +++ b/monai/csrc/filtering/filtering.h @@ -0,0 +1,16 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include "bilateral/bilateral.h" \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/hash_table.cu b/monai/csrc/filtering/permutohedral/hash_table.cu new file mode 100644 index 0000000000..cdda0b4fed --- /dev/null +++ b/monai/csrc/filtering/permutohedral/hash_table.cu @@ -0,0 +1,255 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +//#define USE_ADDITIVE_HASH + +// turn this on if you want to get slighly less memory consumption and slightly longer run times. +//#define LINEAR_D_MEMORY + +#define USE_CUSTOM_MODULO + +__device__ __constant__ signed short* table_keys; +__device__ __constant__ int* table_entries; +__device__ __constant__ unsigned int table_capacity; +__device__ __constant__ signed short* table_zeros; +__device__ __constant__ char* table_rank; + +/*************************************************************/ +/* Fast computation of modulo operator with constant divisor */ +/*************************************************************/ +__device__ __constant__ unsigned int __div_m; +__device__ __constant__ unsigned int __div_l; +__device__ __constant__ unsigned int __div_c; + +#ifdef USE_CUSTOM_MODULO +__device__ inline unsigned int modHash(unsigned int n) { + unsigned int t1 = __umulhi(__div_m, n); + return n - ((t1 + ((n - t1) >> 1)) >> (__div_l - 1)) * __div_c; +} + +#else +#define modHash(n) ((n) % (2 * table_capacity)); +#endif + +/*************************************************************/ +/* End modulo */ +/*************************************************************/ + +__device__ __constant__ static unsigned int hOffset[64]; + +template +static scalar_t* createHashTable(int capacity) { + scalar_t* values; + cudaMalloc(&values, capacity * vd * sizeof(scalar_t)); + cudaMemset(values, 0, capacity * vd * sizeof(scalar_t)); + + int* entries; + cudaMalloc(&entries, capacity * 2 * sizeof(int)); + cudaMemset(entries, -1, capacity * 2 * sizeof(int)); + + cudaMemcpyToSymbol(table_capacity, &capacity, sizeof(int)); + + cudaMemcpyToSymbol(table_entries, &entries, sizeof(int*)); + +#ifdef LINEAR_D_MEMORY + + char* ranks; + cudaMalloc(&ranks, capacity * sizeof(char)); + + signed short* zeros; + cudaMalloc(&zeros, capacity * sizeof(signed short)); + + cudaMemcpyToSymbol(table_rank, &ranks, sizeof(char*)); + cudaMemcpyToSymbol(table_zeros, &zeros, sizeof(char*)); + +#else + + signed short* keys; + cudaMalloc(&keys, capacity * kd * sizeof(signed short)); + cudaMemset(keys, 0, capacity * kd * sizeof(signed short)); + + cudaMemcpyToSymbol(table_keys, &keys, sizeof(unsigned int*)); + +#endif + + return values; +} + +template +static void destroyHashTable() { +#ifndef LINEAR_D_MEMORY + cudaFree(table_keys); +#endif + cudaFree(table_entries); +} + +template +__device__ __host__ static unsigned int hash(signed short* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ __host__ static unsigned int hash(int* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ static bool matchKey(int idx, signed short* key) { + bool match = true; + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); + + for (int i = 0; i < d && match; i++) { + match = (key[i] == zero[i] + color - (rank[i] > d - color ? (d + 1) : 0)); + } + + return match; +} + +template +__device__ static void generateKey(int idx, signed short* key) { + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); + + for (int i = 0; i < d; i++) { + key[i] = zero[i] + color - (rank[i] > d - color ? (d + 1) : 0); + } +} + +template +__device__ static int hashTableInsert(unsigned int fh, signed short* key, unsigned int slot) { + int h = modHash(fh); + while (1) { + int* e = &table_entries[h]; + + // If the cell is empty (-1), lock it (-2) + int contents = atomicCAS(e, -1, -2); + + if (contents == -2) { + // If it was locked already, move on to the next cell + } else if (contents == -1) { + // If it was empty, we successfully locked it. Write our key. + +#ifndef LINEAR_D_MEMORY + for (int i = 0; i < kd; i++) { + table_keys[slot * kd + i] = key[i]; + } +#endif + + // Unlock + atomicExch(e, slot); + + return h; + } else { +// The cell is unlocked and has a key in it, check if it matches +#ifdef LINEAR_D_MEMORY + if (matchKey(contents, key)) + return h; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[contents * kd + i] == key[i]); + } + + if (match) + return h; +#endif + } + // increment the bucket with wraparound + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableInsert(signed short* key, unsigned int slot) { + unsigned int myHash = hash(key); + return hashTableInsert(myHash, key, slot); +} + +template +__device__ static int hashTableRetrieveWithHash(unsigned int fh, signed short* key) { + int h = modHash(fh); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableRetrieve(signed short* key) { + int h = modHash(hash(key)); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h new file mode 100644 index 0000000000..7f57c91a78 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral.h @@ -0,0 +1,20 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once +template +scalar_t* PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount); +#ifdef WITH_CUDA +template +void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate); +#endif diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp new file mode 100644 index 0000000000..597bf263c1 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -0,0 +1,516 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#include +#include + +#include + +using namespace std; + +/***************************************************************/ +/* Hash table implementation for permutohedral lattice + * + * The lattice points are stored sparsely using a hash table. + * The key for each point is its spatial location in the (d+1)- + * dimensional space. + */ +/***************************************************************/ +template +class HashTablePermutohedral { + public: + /* Constructor + * kd_: the dimensionality of the position vectors on the hyperplane. + * vd_: the dimensionality of the value vectors + */ + HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_) { + capacity = 1 << 15; + filled = 0; + entries = new Entry[capacity]; + keys = new short[kd * capacity / 2]; + values = new scalar_t[vd * capacity / 2]; + memset(values, 0, sizeof(scalar_t) * vd * capacity / 2); + } + + // Returns the number of vectors stored. + int size() { + return filled; + } + + // Returns a pointer to the keys array. + short* getKeys() { + return keys; + } + + // Returns a pointer to the values array. + scalar_t* getValues() { + return values; + } + + /* Returns the index into the hash table for a given key. + * key: a pointer to the position vector. + * h: hash of the position vector. + * create: a flag specifying whether an entry should be created, + * should an entry with the given key not found. + */ + int lookupOffset(short* key, size_t h, bool create = true) { + // Double hash table size if necessary + if (filled >= (capacity / 2) - 1) { + grow(); + } + + // Find the entry with the given key + while (1) { + Entry e = entries[h]; + // check if the cell is empty + if (e.keyIdx == -1) { + if (!create) + return -1; // Return not found. + // need to create an entry. Store the given key. + for (int i = 0; i < kd; i++) + keys[filled * kd + i] = key[i]; + e.keyIdx = filled * kd; + e.valueIdx = filled * vd; + entries[h] = e; + filled++; + return e.valueIdx; + } + + // check if the cell has a matching key + bool match = true; + for (int i = 0; i < kd && match; i++) + match = keys[e.keyIdx + i] == key[i]; + if (match) + return e.valueIdx; + + // increment the bucket with wraparound + h++; + if (h == capacity) + h = 0; + } + } + + /* Looks up the value vector associated with a given key vector. + * k : pointer to the key vector to be looked up. + * create : true if a non-existing key should be created. + */ + scalar_t* lookup(short* k, bool create = true) { + size_t h = hash(k) % capacity; + int offset = lookupOffset(k, h, create); + if (offset < 0) + return NULL; + else + return values + offset; + }; + + /* Hash function used in this implementation. A simple base conversion. */ + size_t hash(const short* key) { + size_t k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k *= 2531011; + } + return k; + } + + private: + /* Grows the size of the hash table */ + void grow() { + size_t oldCapacity = capacity; + capacity *= 2; + + // Migrate the value vectors. + scalar_t* newValues = new scalar_t[vd * capacity / 2]; + memset(newValues, 0, sizeof(scalar_t) * vd * capacity / 2); + memcpy(newValues, values, sizeof(scalar_t) * vd * filled); + delete[] values; + values = newValues; + + // Migrate the key vectors. + short* newKeys = new short[kd * capacity / 2]; + memcpy(newKeys, keys, sizeof(short) * kd * filled); + delete[] keys; + keys = newKeys; + + Entry* newEntries = new Entry[capacity]; + + // Migrate the table of indices. + for (size_t i = 0; i < oldCapacity; i++) { + if (entries[i].keyIdx == -1) + continue; + size_t h = hash(keys + entries[i].keyIdx) % capacity; + while (newEntries[h].keyIdx != -1) { + h++; + if (h == capacity) + h = 0; + } + newEntries[h] = entries[i]; + } + delete[] entries; + entries = newEntries; + } + + // Private struct for the hash table entries. + struct Entry { + Entry() : keyIdx(-1), valueIdx(-1) {} + int keyIdx; + int valueIdx; + }; + + short* keys; + scalar_t* values; + Entry* entries; + size_t capacity, filled; + int kd, vd; +}; + +/***************************************************************/ +/* The algorithm class that performs the filter + * + * PermutohedralLattice::filter(...) does all the work. + * + */ +/***************************************************************/ +template +class PermutohedralLattice { + public: + /* Filters given image against a reference image. + * im : image to be bilateral-filtered. + * ref : reference image whose edges are to be respected. + */ + static scalar_t* filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + // Create lattice + PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount); + + // Splat into the lattice + scalar_t* col = new scalar_t[dataChannels + 1]; + col[dataChannels] = 1; // homogeneous coordinate + + for (int i = 0, e = 0; e < elementCount; e++) { + for (int c = 0; c < dataChannels; c++, i++) { + col[c] = data[i]; + } + + scalar_t* featureVec = features + e * featureChannels; + lattice.splat(featureVec, col); + } + + // Blur the lattice + lattice.blur(); + + // Slice from the lattice + scalar_t* outputData = new scalar_t[elementCount * dataChannels]; + + lattice.beginSlice(); + + for (int i = 0, e = 0; e < elementCount; e++) { + lattice.slice(col); + + scalar_t scale = 1.0f / col[dataChannels]; + for (int c = 0; c < dataChannels; c++, i++) { + outputData[i] = col[c] * scale; + } + } + + return outputData; + } + + /* Constructor + * d_ : dimensionality of key vectors + * vd_ : dimensionality of value vectors + * nData_ : number of points in the input + */ + PermutohedralLattice(int d_, int vd_, int nData_) : d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_) { + // Allocate storage for various arrays + elevated = new scalar_t[d + 1]; + scaleFactor = new scalar_t[d]; + + greedy = new short[d + 1]; + rank = new char[d + 1]; + barycentric = new scalar_t[d + 2]; + replay = new ReplayEntry[nData * (d + 1)]; + nReplay = 0; + canonical = new short[(d + 1) * (d + 1)]; + key = new short[d + 1]; + + // compute the coordinates of the canonical simplex, in which + // the difference between a contained point and the zero + // remainder vertex is always in ascending order. (See pg.4 of paper.) + for (int i = 0; i <= d; i++) { + for (int j = 0; j <= d - i; j++) + canonical[i * (d + 1) + j] = i; + for (int j = d - i + 1; j <= d; j++) + canonical[i * (d + 1) + j] = i - (d + 1); + } + + // Compute parts of the rotation matrix E. (See pg.4-5 of paper.) + for (int i = 0; i < d; i++) { + // the diagonal entries for normalization + scaleFactor[i] = 1.0f / (sqrtf((scalar_t)(i + 1) * (i + 2))); + + /* We presume that the user would like to do a Gaussian blur of standard deviation + * 1 in each dimension (or a total variance of d, summed over dimensions.) + * Because the total variance of the blur performed by this algorithm is not d, + * we must scale the space to offset this. + * + * The total variance of the algorithm is (See pg.6 and 10 of paper): + * [variance of splatting] + [variance of blurring] + [variance of splatting] + * = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12 + * = 2d(d+1)(d+1)/3. + * + * So we need to scale the space by (d+1)sqrt(2/3). + */ + scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3); + } + } + + /* Performs splatting with given position and value vectors */ + void splat(scalar_t* position, scalar_t* value) { + // first rotate position into the (d+1)-dimensional hyperplane + elevated[d] = -d * position[d - 1] * scaleFactor[d - 1]; + for (int i = d - 1; i > 0; i--) + elevated[i] = + (elevated[i + 1] - i * position[i - 1] * scaleFactor[i - 1] + (i + 2) * position[i] * scaleFactor[i]); + elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0]; + + // prepare to find the closest lattice points + scalar_t scale = 1.0f / (d + 1); + char* myrank = rank; + short* mygreedy = greedy; + + // greedily search for the closest zero-colored lattice point + int sum = 0; + for (int i = 0; i <= d; i++) { + scalar_t v = elevated[i] * scale; + scalar_t up = ceilf(v) * (d + 1); + scalar_t down = floorf(v) * (d + 1); + + if (up - elevated[i] < elevated[i] - down) + mygreedy[i] = (short)up; + else + mygreedy[i] = (short)down; + + sum += mygreedy[i]; + } + sum /= d + 1; + + // rank differential to find the permutation between this simplex and the canonical one. + // (See pg. 3-4 in paper.) + memset(myrank, 0, sizeof(char) * (d + 1)); + for (int i = 0; i < d; i++) + for (int j = i + 1; j <= d; j++) + if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j]) + myrank[i]++; + else + myrank[j]++; + + if (sum > 0) { + // sum too large - the point is off the hyperplane. + // need to bring down the ones with the smallest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] >= d + 1 - sum) { + mygreedy[i] -= d + 1; + myrank[i] += sum - (d + 1); + } else + myrank[i] += sum; + } + } else if (sum < 0) { + // sum too small - the point is off the hyperplane + // need to bring up the ones with largest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] < -sum) { + mygreedy[i] += d + 1; + myrank[i] += (d + 1) + sum; + } else + myrank[i] += sum; + } + } + + // Compute barycentric coordinates (See pg.10 of paper.) + memset(barycentric, 0, sizeof(scalar_t) * (d + 2)); + for (int i = 0; i <= d; i++) { + barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale; + barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale; + } + barycentric[0] += 1.0f + barycentric[d + 1]; + + // Splat the value into each vertex of the simplex, with barycentric weights. + for (int remainder = 0; remainder <= d; remainder++) { + // Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they + // sum to zero) + for (int i = 0; i < d; i++) + key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]]; + + // Retrieve pointer to the value at this vertex. + scalar_t* val = hashTable.lookup(key, true); + + // Accumulate values with barycentric weight. + for (int i = 0; i < vd; i++) + val[i] += barycentric[remainder] * value[i]; + + // Record this interaction to use later when slicing + replay[nReplay].offset = val - hashTable.getValues(); + replay[nReplay].weight = barycentric[remainder]; + nReplay++; + } + } + + // Prepare for slicing + void beginSlice() { + nReplay = 0; + } + + /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex + * containing each position vector were calculated and stored in the splatting step. + * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.) + */ + void slice(scalar_t* col) { + scalar_t* base = hashTable.getValues(); + for (int j = 0; j < vd; j++) + col[j] = 0; + for (int i = 0; i <= d; i++) { + ReplayEntry r = replay[nReplay++]; + for (int j = 0; j < vd; j++) { + col[j] += r.weight * base[r.offset + j]; + } + } + } + + /* Performs a Gaussian blur along each projected axis in the hyperplane. */ + void blur() { + // Prepare arrays + short* neighbor1 = new short[d + 1]; + short* neighbor2 = new short[d + 1]; + scalar_t* newValue = new scalar_t[vd * hashTable.size()]; + scalar_t* oldValue = hashTable.getValues(); + scalar_t* hashTableBase = oldValue; + + scalar_t* zero = new scalar_t[vd]; + for (int k = 0; k < vd; k++) + zero[k] = 0; + + // For each of d+1 axes, + for (int j = 0; j <= d; j++) { + // For each vertex in the lattice, + for (int i = 0; i < hashTable.size(); i++) { // blur point i in dimension j + short* key = hashTable.getKeys() + i * (d); // keys to current vertex + for (int k = 0; k < d; k++) { + neighbor1[k] = key[k] + 1; + neighbor2[k] = key[k] - 1; + } + neighbor1[j] = key[j] - d; + neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis. + + scalar_t* oldVal = oldValue + i * vd; + scalar_t* newVal = newValue + i * vd; + + scalar_t *vm1, *vp1; + + vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor + if (vm1) + vm1 = vm1 - hashTableBase + oldValue; + else + vm1 = zero; + + vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor + if (vp1) + vp1 = vp1 - hashTableBase + oldValue; + else + vp1 = zero; + + // Mix values of the three vertices + for (int k = 0; k < vd; k++) + newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]); + } + scalar_t* tmp = newValue; + newValue = oldValue; + oldValue = tmp; + // the freshest data is now in oldValue, and newValue is ready to be written over + } + + // depending where we ended up, we may have to copy data + if (oldValue != hashTableBase) { + memcpy(hashTableBase, oldValue, hashTable.size() * vd * sizeof(scalar_t)); + delete oldValue; + } else { + delete newValue; + } + + delete zero; + delete neighbor1; + delete neighbor2; + } + + private: + int d, vd, nData; + scalar_t *elevated, *scaleFactor, *barycentric; + short* canonical; + short* key; + + // slicing is done by replaying splatting (ie storing the sparse matrix) + struct ReplayEntry { + int offset; + scalar_t weight; + } * replay; + int nReplay, nReplaySub; + + public: + char* rank; + short* greedy; + HashTablePermutohedral hashTable; +}; + +template +scalar_t* PermutohedralCPU( + scalar_t* data, + scalar_t* features, + int dataChannels, + int featureChannels, + int elementCount) { + return PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); +} + +template float* PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); +template double* PermutohedralCPU( + double* data, + double* features, + int dataChannels, + int featureChannels, + int elementCount); \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu new file mode 100644 index 0000000000..c60d0d8c31 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -0,0 +1,537 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#define BLOCK_SIZE 64 + +#include +#include +#include +#include +#include + +#include "hash_table.cu" +#include "utils/meta_macros.h" + +template +struct MatrixEntry { + int index; + scalar_t weight; +}; + +template +__global__ static void createMatrix( + const int elementCount, + const scalar_t* positions, + const scalar_t* values, + const scalar_t* scaleFactor, + MatrixEntry* matrix) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + scalar_t myElevated[pd + 1]; + const scalar_t* myPosition = positions + idx * pd; + + int myGreedy[pd + 1]; + int myRank[pd + 1]; + + scalar_t myBarycentric[pd + 2]; + __shared__ short keys[pd * BLOCK_SIZE]; + short* myKey = keys + threadId * pd; + + if (!outOfBounds) { + myElevated[pd] = -pd * myPosition[pd - 1] * scaleFactor[pd - 1]; + + for (int i = pd - 1; i > 0; i--) { + myElevated[i] = + myElevated[i + 1] - i * (myPosition[i - 1]) * scaleFactor[i - 1] + (i + 2) * myPosition[i] * scaleFactor[i]; + } + + myElevated[0] = myElevated[1] + 2 * myPosition[0] * scaleFactor[0]; + + // find the closest zero-colored lattice point + + // greedily search for the closest zero-colored lattice point + signed short sum = 0; + + for (int i = 0; i <= pd; i++) { + scalar_t v = myElevated[i] * (1.0f / (pd + 1)); + scalar_t up = ceilf(v) * (pd + 1); + scalar_t down = floorf(v) * (pd + 1); + + myGreedy[i] = (signed short)(up - myElevated[i] < myElevated[i] - down ? up : down); + sum += myGreedy[i]; + } + + sum /= pd + 1; + + // sort differential to find the permutation between this simplex and the canonical one + for (int i = 0; i <= pd; i++) { + myRank[i] = 0; + + for (int j = 0; j <= pd; j++) { + scalar_t iDiff = myElevated[i] - myGreedy[i]; + scalar_t jDiff = myElevated[j] - myGreedy[j]; + + if (iDiff < jDiff || (iDiff == jDiff && i > j)) { + myRank[i]++; + } + } + } + + if (sum > 0) // sum too large, need to bring down the ones with the smallest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] >= pd + 1 - sum) { + myGreedy[i] -= (pd + 1); + myRank[i] += sum - (pd + 1); + } else { + myRank[i] += sum; + } + } + } else if (sum < 0) // sum too small, need to bring up the ones with largest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] < -sum) { + myGreedy[i] += (pd + 1); + myRank[i] += sum + (pd + 1); + } else { + myRank[i] += sum; + } + } + } + +#ifdef LINEAR_D_MEMORY + for (int i = 0; i <= pd; i++) { + table_zeros[idx * (pd + 1) + i] = myGreedy[i]; + table_rank[idx * (pd + 1) + i] = myRank[i]; + } +#endif + + // turn delta into barycentric coords + for (int i = 0; i <= pd + 1; i++) { + myBarycentric[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + scalar_t delta = (myElevated[i] - myGreedy[i]) * (1.0f / (pd + 1)); + myBarycentric[pd - myRank[i]] += delta; + myBarycentric[pd + 1 - myRank[i]] -= delta; + } + + myBarycentric[0] += 1.0f + myBarycentric[pd + 1]; + } + +#ifdef USE_ADDITIVE_HASH + unsigned int cumulative_hash = hash(myGreedy); +#endif + + for (int color = 0; color <= pd; color++) { + // Compute the location of the lattice point explicitly (all but + // the last coordinate - it's redundant because they sum to zero) + if (!outOfBounds) { + for (int i = 0; i < pd; i++) { + myKey[i] = myGreedy[i] + color; + + if (myRank[i] > pd - color) { + myKey[i] -= (pd + 1); + } + } + } + +#ifdef USE_ADDITIVE_HASH + for (int i = 0; i < pd; i++) { + if (myRank[i] == pd - color) { + cumulative_hash += hOffset[i]; + } + } +#endif + + if (!outOfBounds) { + MatrixEntry r; + +#ifdef USE_ADDITIVE_HASH + r.index = hashTableInsert(cumulative_hash, myKey, idx * (pd + 1) + color); +#else + r.index = hashTableInsert(myKey, idx * (pd + 1) + color); +#endif + + r.weight = myBarycentric[color]; + matrix[idx * (pd + 1) + color] = r; + } + } +} + +template +__global__ static void cleanHashTable(const int elementCount, MatrixEntry* matrix) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx >= elementCount) + return; + + // find my hash table entry + int* e = table_entries + idx; + + // Check if I created my own key in the previous phase + if (*e >= 0) { + // Rehash my key and reset the pointer in order to merge with + // any other pixel that created a different entry under the + // same key. If the computation was serial this would never + // happen, but sometimes race conditions can make the same key + // be inserted twice. hashTableRetrieve always returns the + // earlier, so it's no problem as long as we rehash now. + +#ifdef LINEAR_D_MEMORY + // Get my key + short myKey[kd]; + generateKey(*e, myKey); + *e = hashTableRetrieve(myKey); +#else + *e = hashTableRetrieve(table_keys + *e * kd); +#endif + } +} + +template +__global__ static void splat( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) { + return; + } + + scalar_t* myValue = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + gpuAtomicAdd(val + j, myValue[j] * r.weight); + } + + gpuAtomicAdd(val + vd, r.weight); +} + +// splat splits by color, so extend the y coordinate to our blocks to represent that +// dim3 oldblocks((w-1)/8+1, (h-1)/8+1, 1); +// dim3 oldblockSize(8, 8, 1); +// oldblocks.y *= pd+1; +// splatCache<<>>(w, h, values, matrix); + +// int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; +// int blockSize = BLOCK_SIZE; + +// splatCache<<>>(elementCount, values, matrix); + +template +__global__ static void splatCache( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + // const int x = threadIdx.x + blockIdx.x * blockDim.x; + // const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y; + + // const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + // const int color = blockIdx.y % (pd+1); + // const int idx = y*w + x; + + const int threadId = threadIdx.x; + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + + const bool outOfBounds = idx >= elementCount; + + __shared__ int sharedOffsets[BLOCK_SIZE]; + __shared__ scalar_t sharedValues[BLOCK_SIZE * (vd + 1)]; + + int myOffset = -1; + scalar_t* myValue = sharedValues + threadId * (vd + 1); + + if (!outOfBounds) { + scalar_t* value = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + // convert the matrix entry from a pointer into the entries array to a pointer into the keys/values array + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + // record the offset into the keys/values array in shared space + myOffset = sharedOffsets[threadId] = r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] = value[j] * r.weight; + } + myValue[vd] = r.weight; + + } else { + sharedOffsets[threadId] = -1; + } + + __syncthreads(); + + // am I the first thread in this block to care about this key? + + if (outOfBounds) + return; + + for (int i = 0; i < BLOCK_SIZE; i++) { + if (i < threadId) { + if (myOffset == sharedOffsets[i]) { + // somebody else with higher priority cares about this key + return; + } + } else if (i > threadId) { + if (myOffset == sharedOffsets[i]) { + // someone else with lower priority cares about this key, accumulate it into mine + for (int j = 0; j <= vd; j++) { + sharedValues[threadId * (vd + 1) + j] += sharedValues[i * (vd + 1) + j]; + } + } + } + } + + // only the threads with something to write to main memory are still going + scalar_t* val = table_values + myOffset; + for (int j = 0; j <= vd; j++) { + gpuAtomicAdd(val + j, myValue[j]); + } +} + +template +__global__ static void blur( + int n, + scalar_t* newValues, + MatrixEntry* matrix, + int color, + scalar_t* table_values) { + const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x; + + if (idx >= n) + return; + + // Check if I'm valid + if (matrix[idx].index != idx) + return; + + // find my key and the keys of my neighbours + short myKey[pd + 1]; + short np[pd + 1]; + short nm[pd + 1]; + +#ifdef LINEAR_D_MEMORY + generateKey(idx, myKey); + for (int i = 0; i < pd; i++) { + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#else + for (int i = 0; i < pd; i++) { + myKey[i] = table_keys[idx * pd + i]; + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#endif + + np[color] -= pd + 1; + nm[color] += pd + 1; + +#ifdef USE_ADDITIVE_HASH + unsigned int hCurrent = hash(myKey); + int offNp = hashTableRetrieveWithHash(hCurrent + hOffset[color], np); + int offNm = hashTableRetrieveWithHash(hCurrent - hOffset[color], nm); +#else + int offNp = hashTableRetrieve(np); + int offNm = hashTableRetrieve(nm); +#endif + + scalar_t* valMe = table_values + (vd + 1) * idx; + scalar_t* valNp = table_values + (vd + 1) * offNp; + scalar_t* valNm = table_values + (vd + 1) * offNm; + scalar_t* valOut = newValues + (vd + 1) * idx; + + if (offNp >= 0 && offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2) + valNm[i]) / 4; + } + } else if (offNp >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2)) / 4; + } + } else if (offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNm[i] + (valMe[i] * 2)) / 4; + } + } else { + for (int i = 0; i <= vd; i++) { + valOut[i] = valMe[i] * 2; + } + } +} + +template +__global__ static void slice( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) + return; + + __shared__ scalar_t localValue[BLOCK_SIZE * vd]; + + scalar_t* myValue = localValue + threadId * vd; + scalar_t myWeight = 0; + + for (int i = 0; i < vd; i++) { + myValue[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + MatrixEntry r = matrix[idx * (pd + 1) + i]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] += r.weight * val[j]; + } + + myWeight += r.weight * val[vd]; + } + + myWeight = 1.0f / myWeight; + + for (int j = 0; j < vd; j++) { + values[idx * vd + j] = myValue[j] * myWeight; + } +} + +template +void PermutohedralCuda(scalar_t* values, scalar_t* positions, int elementCount, bool accurate) { + scalar_t blurVariance = accurate ? 0.5 : 0; + + scalar_t* scaleFactor; + cudaMalloc(&scaleFactor, pd * sizeof(scalar_t)); + + scalar_t scaleFactorHost[pd]; + for (int i = 0; i < pd; i++) { + scaleFactorHost[i] = (pd + 1) * sqrtf((1.0 / 6 + blurVariance) / ((i + 1) * (i + 2))); + } + + cudaMemcpy(scaleFactor, scaleFactorHost, pd * sizeof(scalar_t), cudaMemcpyHostToDevice); + + MatrixEntry* matrix; + cudaMalloc(&matrix, elementCount * (pd + 1) * sizeof(MatrixEntry)); + + scalar_t* table_values = createHashTable(elementCount * (pd + 1)); + + // Populate constant memory for hash helpers + unsigned long long int __host_two32 = ((unsigned long long int)1) << 32; + unsigned int __host_div_c = 2 * (elementCount * (pd + 1)); + unsigned int __host_div_l = ceilf(logf((float)__host_div_c) / logf(2.0f)); + unsigned int __host_div_m = (__host_two32 << __host_div_l) / __host_div_c - __host_two32 + 1; + cudaMemcpyToSymbol(__div_c, &__host_div_c, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_l, &__host_div_l, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_m, &__host_div_m, sizeof(unsigned int)); + + // Populate constant memory with hash of offset vectors + unsigned int hOffset_host[pd + 1]; + signed short offset[pd + 1]; + for (int i = 0; i < pd; offset[i] = 1, i++) + ; + for (int i = 0; i <= pd; i++) { + offset[i] -= pd + 1; + hOffset_host[i] = hash(offset); + offset[i] += pd + 1; + } + cudaMemcpyToSymbol(hOffset, &hOffset_host, sizeof(unsigned int) * (pd + 1)); + + int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; + int blockSize = BLOCK_SIZE; + + createMatrix<<>>(elementCount, positions, values, scaleFactor, matrix); + + // fix duplicate hash table entries + int tableSize = elementCount * 2 * (pd + 1); + int cleanBlockSize = 32; + int cleanBlocks = (tableSize - 1) / cleanBlockSize + 1; + + cleanHashTable<<>>(tableSize, matrix); + + splat<<>>(elementCount, values, matrix, table_values); + + if (accurate) { + scalar_t* newValues; + cudaMalloc(&newValues, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + cudaMemset(newValues, 0, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + + for (int color = 0; color <= pd; color++) { + blur + <<>>(elementCount * (pd + 1), newValues, matrix, color, table_values); + + scalar_t* swap = newValues; + newValues = table_values; + table_values = swap; + } + + cudaFree(newValues); + } + + slice<<>>(elementCount, values, matrix, table_values); + + destroyHashTable(); + cudaFree(table_values); +} + +#define DECLARATION(dc, fc) \ + template void PermutohedralCuda(float* values, float* positions, int elementCount, bool accurate); \ + template void PermutohedralCuda(double* values, double* positions, int elementCount, bool accurate); +DO_FOR_AB(DECLARATION, 16, 19) diff --git a/monai/csrc/utils/meta_macros.h b/monai/csrc/utils/meta_macros.h new file mode 100644 index 0000000000..73d1851198 --- /dev/null +++ b/monai/csrc/utils/meta_macros.h @@ -0,0 +1,131 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +// Helper Macros: for internal use (see below) +#define _DO_1(TARGET) TARGET(1) +#define _DO_2(TARGET) TARGET(2) _DO_1(TARGET) +#define _DO_3(TARGET) TARGET(3) _DO_2(TARGET) +#define _DO_4(TARGET) TARGET(4) _DO_3(TARGET) +#define _DO_5(TARGET) TARGET(5) _DO_4(TARGET) +#define _DO_6(TARGET) TARGET(6) _DO_5(TARGET) +#define _DO_7(TARGET) TARGET(7) _DO_6(TARGET) +#define _DO_8(TARGET) TARGET(8) _DO_7(TARGET) +#define _DO_9(TARGET) TARGET(9) _DO_8(TARGET) +#define _DO_10(TARGET) TARGET(10) _DO_9(TARGET) +#define _DO_11(TARGET) TARGET(11) _DO_10(TARGET) +#define _DO_12(TARGET) TARGET(12) _DO_11(TARGET) +#define _DO_13(TARGET) TARGET(13) _DO_12(TARGET) +#define _DO_14(TARGET) TARGET(14) _DO_13(TARGET) +#define _DO_15(TARGET) TARGET(15) _DO_14(TARGET) +#define _DO_16(TARGET) TARGET(16) _DO_15(TARGET) +#define _DO_17(TARGET) TARGET(17) _DO_16(TARGET) +#define _DO_18(TARGET) TARGET(18) _DO_17(TARGET) +#define _DO_19(TARGET) TARGET(19) _DO_18(TARGET) +#define _DO_20(TARGET) TARGET(20) _DO_19(TARGET) +#define _DO_21(TARGET) TARGET(21) _DO_20(TARGET) +#define _DO_22(TARGET) TARGET(22) _DO_21(TARGET) +#define _DO_23(TARGET) TARGET(23) _DO_22(TARGET) +#define _DO_24(TARGET) TARGET(24) _DO_23(TARGET) +#define _DO_25(TARGET) TARGET(25) _DO_24(TARGET) +#define _DO_26(TARGET) TARGET(26) _DO_25(TARGET) +#define _DO_27(TARGET) TARGET(27) _DO_26(TARGET) +#define _DO_28(TARGET) TARGET(28) _DO_27(TARGET) +#define _DO_29(TARGET) TARGET(29) _DO_28(TARGET) +#define _DO_30(TARGET) TARGET(30) _DO_29(TARGET) +#define _DO_31(TARGET) TARGET(31) _DO_30(TARGET) +#define _DO_32(TARGET) TARGET(32) _DO_31(TARGET) + +#define _DO_A_1(TARGET, A) TARGET(A, 1) +#define _DO_A_2(TARGET, A) TARGET(A, 2) _DO_A_1(TARGET, A) +#define _DO_A_3(TARGET, A) TARGET(A, 3) _DO_A_2(TARGET, A) +#define _DO_A_4(TARGET, A) TARGET(A, 4) _DO_A_3(TARGET, A) +#define _DO_A_5(TARGET, A) TARGET(A, 5) _DO_A_4(TARGET, A) +#define _DO_A_6(TARGET, A) TARGET(A, 6) _DO_A_5(TARGET, A) +#define _DO_A_7(TARGET, A) TARGET(A, 7) _DO_A_6(TARGET, A) +#define _DO_A_8(TARGET, A) TARGET(A, 8) _DO_A_7(TARGET, A) +#define _DO_A_9(TARGET, A) TARGET(A, 9) _DO_A_8(TARGET, A) +#define _DO_A_10(TARGET, A) TARGET(A, 10) _DO_A_9(TARGET, A) +#define _DO_A_11(TARGET, A) TARGET(A, 11) _DO_A_10(TARGET, A) +#define _DO_A_12(TARGET, A) TARGET(A, 12) _DO_A_11(TARGET, A) +#define _DO_A_13(TARGET, A) TARGET(A, 13) _DO_A_12(TARGET, A) +#define _DO_A_14(TARGET, A) TARGET(A, 14) _DO_A_13(TARGET, A) +#define _DO_A_15(TARGET, A) TARGET(A, 15) _DO_A_14(TARGET, A) +#define _DO_A_16(TARGET, A) TARGET(A, 16) _DO_A_15(TARGET, A) +#define _DO_A_17(TARGET, A) TARGET(A, 17) _DO_A_16(TARGET, A) +#define _DO_A_18(TARGET, A) TARGET(A, 18) _DO_A_17(TARGET, A) +#define _DO_A_19(TARGET, A) TARGET(A, 19) _DO_A_18(TARGET, A) +#define _DO_A_20(TARGET, A) TARGET(A, 20) _DO_A_19(TARGET, A) +#define _DO_A_21(TARGET, A) TARGET(A, 21) _DO_A_20(TARGET, A) +#define _DO_A_22(TARGET, A) TARGET(A, 22) _DO_A_21(TARGET, A) +#define _DO_A_23(TARGET, A) TARGET(A, 23) _DO_A_22(TARGET, A) +#define _DO_A_24(TARGET, A) TARGET(A, 24) _DO_A_23(TARGET, A) +#define _DO_A_25(TARGET, A) TARGET(A, 25) _DO_A_24(TARGET, A) +#define _DO_A_26(TARGET, A) TARGET(A, 26) _DO_A_25(TARGET, A) +#define _DO_A_27(TARGET, A) TARGET(A, 27) _DO_A_26(TARGET, A) +#define _DO_A_28(TARGET, A) TARGET(A, 28) _DO_A_27(TARGET, A) +#define _DO_A_29(TARGET, A) TARGET(A, 29) _DO_A_28(TARGET, A) +#define _DO_A_30(TARGET, A) TARGET(A, 30) _DO_A_29(TARGET, A) +#define _DO_A_31(TARGET, A) TARGET(A, 31) _DO_A_30(TARGET, A) +#define _DO_A_32(TARGET, A) TARGET(A, 32) _DO_A_31(TARGET, A) + +#define _DO_1_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 1) +#define _DO_2_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 2) _DO_1_B(TARGET, B_RANGE) +#define _DO_3_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 3) _DO_2_B(TARGET, B_RANGE) +#define _DO_4_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 4) _DO_3_B(TARGET, B_RANGE) +#define _DO_5_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 5) _DO_4_B(TARGET, B_RANGE) +#define _DO_6_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 6) _DO_5_B(TARGET, B_RANGE) +#define _DO_7_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 7) _DO_6_B(TARGET, B_RANGE) +#define _DO_8_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 8) _DO_7_B(TARGET, B_RANGE) +#define _DO_9_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 9) _DO_8_B(TARGET, B_RANGE) +#define _DO_10_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 10) _DO_9_B(TARGET, B_RANGE) +#define _DO_11_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 11) _DO_10_B(TARGET, B_RANGE) +#define _DO_12_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 12) _DO_11_B(TARGET, B_RANGE) +#define _DO_13_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 13) _DO_12_B(TARGET, B_RANGE) +#define _DO_14_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 14) _DO_13_B(TARGET, B_RANGE) +#define _DO_15_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 15) _DO_14_B(TARGET, B_RANGE) +#define _DO_16_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 16) _DO_15_B(TARGET, B_RANGE) +#define _DO_17_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 17) _DO_16_B(TARGET, B_RANGE) +#define _DO_18_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 18) _DO_17_B(TARGET, B_RANGE) +#define _DO_19_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 19) _DO_18_B(TARGET, B_RANGE) +#define _DO_20_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 20) _DO_19_B(TARGET, B_RANGE) +#define _DO_21_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 21) _DO_20_B(TARGET, B_RANGE) +#define _DO_22_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 22) _DO_21_B(TARGET, B_RANGE) +#define _DO_23_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 23) _DO_22_B(TARGET, B_RANGE) +#define _DO_24_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 24) _DO_23_B(TARGET, B_RANGE) +#define _DO_25_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 25) _DO_24_B(TARGET, B_RANGE) +#define _DO_26_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 26) _DO_25_B(TARGET, B_RANGE) +#define _DO_27_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 27) _DO_26_B(TARGET, B_RANGE) +#define _DO_28_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 28) _DO_27_B(TARGET, B_RANGE) +#define _DO_29_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 29) _DO_28_B(TARGET, B_RANGE) +#define _DO_30_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 30) _DO_29_B(TARGET, B_RANGE) +#define _DO_31_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 31) _DO_30_B(TARGET, B_RANGE) +#define _DO_32_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 32) _DO_31_B(TARGET, B_RANGE) + +#define _CASE_A(A) \ + case (A): \ + CASE(A) break; +#define _CASE_AB(A, B) \ + case (A * 100 + B): \ + CASE(A, B) break; + +// Preproccessor For Loops +#define DO_FOR_A(TARGET, A_RANGE) _DO_##A_RANGE(TARGET) +#define DO_FOR_AB(TARGET, A_RANGE, B_RANGE) _DO_##A_RANGE##_B(TARGET, B_RANGE) + +// Preproccessor Switch Statement Generators +#define SWITCH_A(CASE, A_RANGE, A) \ + switch (A) { DO_FOR_A(_CASE_A, A_RANGE) } +#define SWITCH_AB(CALL, A_RANGE, B_RANGE, A, B) \ + switch (A * 100 + B) { DO_FOR_AB(_CASE_AB, A_RANGE, B_RANGE) } diff --git a/monai/csrc/utils/tensor_description.h b/monai/csrc/utils/tensor_description.h new file mode 100644 index 0000000000..6072037f72 --- /dev/null +++ b/monai/csrc/utils/tensor_description.h @@ -0,0 +1,40 @@ + +#include + +// Struct to easily cache descriptive information about a tensor. +// This is helpful as regular calls to the size and stride member +// functions of tensors appear to cause memory issues. +struct TensorDescription { + public: + TensorDescription(torch::Tensor tensor) { + batchCount = tensor.size(0); + batchStride = tensor.stride(0); + + channelCount = tensor.size(1); + channelStride = tensor.stride(1); + + dimensions = tensor.dim() - 2; + sizes = new int[dimensions]; + strides = new int[dimensions]; + + for (int i = 0; i < dimensions; i++) { + sizes[i] = tensor.size(i + 2); + strides[i] = tensor.stride(i + 2); + } + } + + ~TensorDescription() { + delete[] sizes; + delete[] strides; + } + + int batchCount; + int batchStride; + + int channelCount; + int channelStride; + + int dimensions; + int* sizes; + int* strides; +}; diff --git a/monai/data/dataset.py b/monai/data/dataset.py index bf2d22f838..892546b2a4 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -321,6 +321,7 @@ def __init__( cache_dir: Union[Path, str] = "cache", hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", + progress: bool = True, pickle_protocol=pickle.HIGHEST_PROTOCOL, lmdb_kwargs: Optional[dict] = None, ) -> None: @@ -338,6 +339,7 @@ def __init__( hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". + progress: whether to display a progress bar. pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. https://docs.python.org/3/library/pickle.html#pickle-protocols lmdb_kwargs: additional keyword arguments to the lmdb environment. @@ -352,15 +354,16 @@ def __init__( if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size self._read_env = None + self.progress = progress + print(f"Accessing lmdb file: {self.db_file.absolute()}.") def _fill_cache_start_reader(self): # create cache - print(f"Accessing lmdb file: {self.db_file.absolute()}.") self.lmdb_kwargs["readonly"] = False env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) if not has_tqdm: warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.") - for item in tqdm(self.data) if has_tqdm else self.data: + for item in tqdm(self.data) if has_tqdm and self.progress else self.data: key = self.hash_func(item) done, retry, val = False, 5, None while not done and retry > 0: diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 6c243f2b08..925772433b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -28,12 +28,16 @@ from itk import Image # type: ignore from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage + + has_itk = has_nib = has_pil = True else: - itk, _ = optional_import("itk", allow_namespace_pkg=True) + itk, has_itk = optional_import("itk", allow_namespace_pkg=True) Image, _ = optional_import("itk", allow_namespace_pkg=True, name="Image") - nib, _ = optional_import("nibabel") + nib, has_nib = optional_import("nibabel") Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") - PILImage, _ = optional_import("PIL.Image") + PILImage, has_pil = optional_import("PIL.Image") + +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"] class ImageReader(ABC): @@ -121,7 +125,7 @@ class ITKReader(ImageReader): def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs - if int(itk.Version.GetITKMajorVersion()) == 5 and int(itk.Version.GetITKMinorVersion()) < 2: + if has_itk and int(itk.Version.GetITKMajorVersion()) == 5 and int(itk.Version.GetITKMinorVersion()) < 2: # warning the ITK LazyLoading mechanism was not threadsafe until version 5.2.0, # requesting access to the itk.imread function triggers the lazy loading of the relevant itk modules # before the parallel use of the function. @@ -136,7 +140,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: if a list of files, verify all the suffixes. """ - return True + return has_itk def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -307,7 +311,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ suffixes: Sequence[str] = ["nii", "nii.gz"] - return is_supported_format(filename, suffixes) + return has_nib and is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -521,8 +525,8 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: filename: file name or a list of file names to read. if a list of files, verify all the suffixes. """ - suffixes: Sequence[str] = ["png", "jpg", "bmp"] - return is_supported_format(filename, suffixes) + suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] + return has_pil and is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index b76e8c7444..9832a7c164 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -22,7 +22,7 @@ class ThreadBuffer: One issue raised by using a thread in this way is that during the lifetime of the thread the source object is being iterated over, so if the thread hasn't finished another attempt to iterate over it will raise an exception or yield - inexpected results. To ensure the thread releases the iteration and proper cleanup is done the stop() method must + unexpected results. To ensure the thread releases the iteration and proper cleanup is done the stop() method must be called which will join with the thread. Args: diff --git a/monai/data/utils.py b/monai/data/utils.py index eaeeee543b..c5fcbf3c86 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -134,8 +134,7 @@ def dense_patch_slices( dim_starts.append(start_idx) starts.append(dim_starts) out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T - slices = [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] - return slices + return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out] def iter_patch( @@ -550,7 +549,7 @@ def is_supported_format(filename: Union[Sequence[str], str], suffixes: Sequence[ filenames: Sequence[str] = ensure_tuple(filename) for name in filenames: tokens: Sequence[str] = PurePath(name).suffixes - if len(tokens) == 0 or not any(("." + s.lower()) in "".join(tokens) for s in suffixes): + if len(tokens) == 0 or all("." + s.lower() not in "".join(tokens) for s in suffixes): return False return True @@ -567,7 +566,7 @@ def partition_dataset( ): """ Split the dataset into N partitions. It can support shuffle based on specified random seed. - Will return a set of datasets, every dataset contains 1 partion of original dataset. + Will return a set of datasets, every dataset contains 1 partition of original dataset. And it can split the dataset based on specified ratios or evenly split into `num_partitions`. Refer to: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py. @@ -598,7 +597,7 @@ def partition_dataset( """ data_len = len(data) - datasets = list() + datasets = [] indices = list(range(data_len)) if shuffle: @@ -682,7 +681,7 @@ def partition_dataset_classes( """ if not classes or len(classes) != len(data): raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.") - datasets = list() + datasets = [] class_indices = defaultdict(list) for i, c in enumerate(classes): class_indices[c].append(i) @@ -698,7 +697,7 @@ def partition_dataset_classes( drop_last=drop_last, even_divisible=even_divisible, ) - if len(class_partition_indices) == 0: + if not class_partition_indices: class_partition_indices = per_class_partition_indices else: for part, data_indices in zip(class_partition_indices, per_class_partition_indices): @@ -735,8 +734,7 @@ def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[S >>> select_cross_validation_folds(partitions, [-1, 2]) [9, 10, 5, 6] """ - data_list = [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]] - return data_list + return [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]] class DistributedSampler(_TorchDistributedSampler): diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 13835f915b..a519210466 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,3 +12,4 @@ from .evaluator import * from .multi_gpu_supervised_trainer import * from .trainer import * +from .utils import * diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 930747edfb..306be5f2db 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -28,6 +28,8 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = ["Evaluator", "SupervisedEvaluator", "EnsembleEvaluator"] + class Evaluator(Workflow): """ diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 7110a09c0f..33268308e5 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -29,6 +29,11 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = [ + "create_multigpu_supervised_trainer", + "create_multigpu_supervised_evaluator", +] + def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float: return loss.item() diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c625d1b669..7ab3a47eba 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -29,6 +29,8 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"] + class Trainer(Workflow): """ @@ -280,7 +282,7 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device) + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) batch_size = self.data_loader.batch_size g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) g_output = self.g_inferer(g_input, self.g_network) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 9becd5c5f6..37715cad52 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .hausdorff_distance import HausdorffDistance from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice from .metric_logger import MetricLogger @@ -20,6 +21,7 @@ from .segmentation_saver import SegmentationSaver from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler +from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler from .utils import * from .validation_handler import ValidationHandler diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 57d8728cd4..0cc05b2dc4 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -188,6 +188,13 @@ def attach(self, engine: Engine) -> None: else: engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.save_interval), self.interval_completed) + def _delete_previous_final_ckpt(self): + saved = self._final_checkpoint._saved + if len(saved) > 0: + item = saved.pop(0) + self._final_checkpoint.save_handler.remove(item.filename) + self.logger.info(f"Deleted previous saved final checkpoint: {item.filename}") + def completed(self, engine: Engine) -> None: """Callback for train or validation/evaluation completed Event. Save final checkpoint if configure save_final is True. @@ -196,6 +203,8 @@ def completed(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified." + # delete previous saved final checkpoint if existing + self._delete_previous_final_ckpt() self._final_checkpoint(engine) assert self.logger is not None assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute." @@ -211,6 +220,8 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: e: the exception caught in Ignite during engine.run(). """ assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified." + # delete previous saved final checkpoint if existing + self._delete_previous_final_ckpt() self._final_checkpoint(engine) assert self.logger is not None assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute." diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py new file mode 100644 index 0000000000..56b8b341ff --- /dev/null +++ b/monai/handlers/hausdorff_distance.py @@ -0,0 +1,99 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence + +import torch + +from monai.metrics import HausdorffDistanceMetric +from monai.utils import MetricReduction, exact_version, optional_import + +NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") + + +class HausdorffDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + include_background: bool = False, + distance_metric: str = "euclidean", + percentile: Optional[float] = None, + directed: bool = False, + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + ) -> None: + """ + + Args: + include_background: whether to include distance computation on the first channel of the predicted output. + Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + percentile: an optional float number between 0 and 100. If specified, the corresponding + percentile of the Hausdorff Distance rather than the maximum result will be achieved. + Defaults to ``None``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + + """ + super().__init__(output_transform, device=device) + self.hd = HausdorffDistanceMetric( + include_background=include_background, + distance_metric=distance_metric, + percentile=percentile, + directed=directed, + reduction=MetricReduction.MEAN, + ) + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def reset(self) -> None: + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score, not_nans = self.hd(y_pred, y) + not_nans = int(not_nans.item()) + + # add all items in current batch + self._sum += score.item() * not_nans + self._num_examples += not_nans + + @sync_all_reduce("_sum", "_num_examples") + def compute(self) -> float: + """ + Raises: + NotComputableError: When ``compute`` is called before an ``update`` occurs. + + """ + if self._num_examples == 0: + raise NotComputableError("HausdorffDistance must have at least one example before it can be computed.") + return self._sum / self._num_examples diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py new file mode 100644 index 0000000000..b35089423c --- /dev/null +++ b/monai/handlers/surface_distance.py @@ -0,0 +1,95 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Sequence + +import torch + +from monai.metrics import SurfaceDistanceMetric +from monai.utils import MetricReduction, exact_version, optional_import + +NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError") +Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced") +sync_all_reduce, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "sync_all_reduce") + + +class SurfaceDistance(Metric): # type: ignore[valid-type, misc] # due to optional_import + """ + Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations. + """ + + def __init__( + self, + include_background: bool = False, + symmetric: bool = False, + distance_metric: str = "euclidean", + output_transform: Callable = lambda x: x, + device: Optional[torch.device] = None, + ) -> None: + """ + + Args: + include_background: whether to include distance computation on the first channel of the predicted output. + Defaults to ``False``. + symmetric: whether to calculate the symmetric average surface distance between + `seg_pred` and `seg_gt`. Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + output_transform: transform the ignite.engine.state.output into [y_pred, y] pair. + device: device specification in case of distributed computation usage. + + """ + super().__init__(output_transform, device=device) + self.hd = SurfaceDistanceMetric( + include_background=include_background, + symmetric=symmetric, + distance_metric=distance_metric, + reduction=MetricReduction.MEAN, + ) + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def reset(self) -> None: + self._sum = 0.0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + """ + Args: + output: sequence with contents [y_pred, y]. + + Raises: + ValueError: When ``output`` length is not 2. The metric can only support y_pred and y. + + """ + if len(output) != 2: + raise ValueError(f"output must have length 2, got {len(output)}.") + y_pred, y = output + score, not_nans = self.hd(y_pred, y) + not_nans = int(not_nans.item()) + + # add all items in current batch + self._sum += score.item() * not_nans + self._num_examples += not_nans + + @sync_all_reduce("_sum", "_num_examples") + def compute(self) -> float: + """ + Raises: + NotComputableError: When ``compute`` is called before an ``update`` occurs. + + """ + if self._num_examples == 0: + raise NotComputableError("SurfaceDistance must have at least one example before it can be computed.") + return self._sum / self._num_examples diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index e401e18b0c..e96521f47e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -21,6 +21,8 @@ else: Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") +__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "all_gather"] + def stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]: """ diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index eea56d3d45..36cc3de478 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -17,6 +17,8 @@ from monai.inferers.utils import sliding_window_inference from monai.utils import BlendMode, PytorchPadMode +__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer"] + class Inferer(ABC): """ diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 48bd334061..c7db520cb2 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -17,6 +17,8 @@ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple +__all__ = ["sliding_window_inference"] + def sliding_window_inference( inputs: torch.Tensor, diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 61e288cf0c..a0d626f45b 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,9 +10,9 @@ # limitations under the License. from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix -from .hausdorff_distance import compute_hausdorff_distance +from .hausdorff_distance import * from .meandice import DiceMetric, compute_meandice from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc -from .surface_distance import compute_average_surface_distance +from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import * diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 8d2304cea3..916a07439f 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -15,6 +15,7 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction class ConfusionMatrixMetric: @@ -256,17 +257,15 @@ def compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Te elif metric == "mk": ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor) npv = torch.where((tn + fn) > 0, tn / (tn + fn), nan_tensor) - npv = tn / (tn + fn) numerator = ppv + npv - 1.0 denominator = 1.0 else: raise NotImplementedError("the metric is not implemented.") if isinstance(denominator, torch.Tensor): - result = torch.where(denominator != 0, numerator / denominator, nan_tensor) + return torch.where(denominator != 0, numerator / denominator, nan_tensor) else: - result = numerator / denominator - return result + return numerator / denominator def check_confusion_matrix_metric_name(metric_name: str): diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1cfaea2449..c649cd3a04 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -9,80 +9,163 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Optional, Union import numpy as np import torch -from .utils import get_mask_edges, get_surface_distance +from monai.metrics.utils import * +from monai.utils import MetricReduction + +__all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"] + + +class HausdorffDistanceMetric: + """ + Compute Hausdorff Distance between two tensors. It can support both multi-classes and multi-labels tasks. + It supports both directed and non-directed Hausdorff distance calculation. In addition, specify the `percentile` + parameter can get the percentile of the distance. + Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. + You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. + + Args: + include_background: whether to include distance computation on the first channel of + the predicted output. Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + percentile: an optional float number between 0 and 100. If specified, the corresponding + percentile of the Hausdorff Distance rather than the maximum result will be achieved. + Defaults to ``None``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + + """ + + def __init__( + self, + include_background: bool = False, + distance_metric: str = "euclidean", + percentile: Optional[float] = None, + directed: bool = False, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + ) -> None: + super().__init__() + self.include_background = include_background + self.distance_metric = distance_metric + self.percentile = percentile + self.directed = directed + self.reduction = reduction + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + """ + Args: + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + + Raises: + ValueError: when `y` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + """ + if not torch.all(y_pred.byte() == y_pred): + warnings.warn("y_pred is not a binarized tensor here!") + if not torch.all(y.byte() == y): + raise ValueError("y should be a binarized tensor.") + dims = y_pred.ndimension() + if dims < 3: + raise ValueError("y_pred should have at least three dimensions.") + # compute (BxC) for each channel for each batch + f = compute_hausdorff_distance( + y_pred=y_pred, + y=y, + include_background=self.include_background, + distance_metric=self.distance_metric, + percentile=self.percentile, + directed=self.directed, + ) + + # do metric reduction + f, not_nans = do_metric_reduction(f, self.reduction) + return f, not_nans def compute_hausdorff_distance( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, + y_pred: Union[np.ndarray, torch.Tensor], + y: Union[np.ndarray, torch.Tensor], + include_background: bool = False, distance_metric: str = "euclidean", percentile: Optional[float] = None, directed: bool = False, ): """ - Compute the Hausdorff distance. The user has the option to calculate the - directed or non-directed Hausdorff distance. By default, the non-directed - Hausdorff distance is calculated. In addition, specify the `percentile` - parameter can get the percentile of the distance. + Compute the Hausdorff distance. Args: - seg_pred: the predicted binary or labelfield image. - seg_gt: the actual binary or labelfield image. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. percentile: an optional float number between 0 and 100. If specified, the corresponding percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. - directed: calculate directed Hausdorff distance. Defaults to ``False``. + directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. """ - (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt, label_idx) - hd = compute_percent_hausdorff_distance(edges_pred, edges_gt, label_idx, distance_metric, percentile) - if directed: - return hd + if not include_background: + y_pred, y = ignore_background( + y_pred=y_pred, + y=y, + ) - hd2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, label_idx, distance_metric, percentile) - return max(hd, hd2) + y = y.float() + y_pred = y_pred.float() + + if y.shape != y_pred.shape: + raise ValueError("y_pred and y should have same shapes.") + + batch_size, n_class = y_pred.shape[:2] + hd = np.empty((batch_size, n_class)) + for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) + if directed: + hd[b, c] = distance_1 + else: + distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) + hd[b, c] = max(distance_1, distance_2) + return torch.from_numpy(hd) def compute_percent_hausdorff_distance( edges_pred: np.ndarray, edges_gt: np.ndarray, - label_idx: int, distance_metric: str = "euclidean", percentile: Optional[float] = None, ): """ This function is used to compute the directed Hausdorff distance. - - Args: - edges_pred: the edge of the predictions. - edges_gt: the edge of the ground truth. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. - distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] - the metric used to compute surface distance. Defaults to ``"euclidean"``. - percentile: an optional float number between 0 and 100. If specified, the corresponding - percentile of the Hausdorff Distance rather than the maximum result will be achieved. - Defaults to ``None``. """ - surface_distance = get_surface_distance(edges_pred, edges_gt, label_idx, distance_metric=distance_metric) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - # for input without foreground + # for both pred and gt do not have foreground if surface_distance.shape == (0,): - return np.inf + return np.nan if not percentile: return surface_distance.max() + elif 0 <= percentile <= 100: return np.percentile(surface_distance, percentile) else: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 18382e7849..53716909fe 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -15,6 +15,7 @@ import torch from monai.metrics.utils import * +from monai.utils import MetricReduction class DiceMetric: @@ -30,7 +31,7 @@ class DiceMetric: Args: include_background: whether to skip Dice computation on the first channel of - the predicted output. Defaults to True. + the predicted output. Defaults to ``True``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. diff --git a/monai/metrics/occlusion_sensitivity.py b/monai/metrics/occlusion_sensitivity.py index 900cfe4645..9879f472a9 100644 --- a/monai/metrics/occlusion_sensitivity.py +++ b/monai/metrics/occlusion_sensitivity.py @@ -10,7 +10,8 @@ # limitations under the License. from collections.abc import Sequence -from typing import Union +from functools import partial +from typing import Optional, Union import numpy as np import torch @@ -18,6 +19,8 @@ try: from tqdm import trange + + trange = partial(trange, desc="Computing occlusion sensitivity") except (ImportError, AttributeError): trange = range @@ -84,7 +87,9 @@ def compute_occlusion_sensitivity( pad_val: float = 0.0, margin: Union[int, Sequence] = 2, n_batch: int = 128, - b_box: Union[Sequence, None] = None, + b_box: Optional[Sequence] = None, + stride: Union[int, Sequence] = 1, + upsample_mode: str = "nearest", ) -> np.ndarray: """ This function computes the occlusion sensitivity for a model's prediction @@ -123,6 +128,13 @@ def compute_occlusion_sensitivity( speed the analysis up, which might be useful for larger images. * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + stride: Stride for performing occlusions. Can be single value or sequence + (for varying stride in the different directions). Should be >= 1. + upsample_mode: If stride != 1 is used, we'll upsample such that the size + of the voxels in the output image match the input. Upsampling is done with + ``torch.nn.Upsample``, and mode can be set to: + * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear`` + * default is ``nearest``. Returns: Numpy array. If no bounding box is supplied, this will be the same size as the input image. If a bounding box is used, the output image will be @@ -147,12 +159,28 @@ def compute_occlusion_sensitivity( # If no bounding box supplied, output shape is same as input shape. # If bounding box is present, shape is max - min + 1 output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - num_required_predictions = np.prod(output_im_shape) + + # Calculate the downsampled shape + if not isinstance(stride, Sequence): + stride_np = np.full_like(im_shape, stride, dtype=np.int32) + stride_np[0] = 1 # always do stride 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + stride_np = np.array(stride, dtype=np.int32) + if stride_np.size != im_shape.size: + raise ValueError("Sizes of image shape and stride should match.") + + # Obviously if stride = 1, downsampled_im_shape == output_im_shape + downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) + downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 + num_required_predictions = np.prod(downsampled_im_shape) # Loop 1D over image for i in trange(num_required_predictions): # Get corresponding ND index - idx = np.unravel_index(i, output_im_shape) + idx = np.unravel_index(i, downsampled_im_shape) + # Multiply by stride + idx *= stride_np # If a bounding box is being used, we need to add on # the min to shift to start of region of interest if b_box_min is not None: @@ -178,11 +206,20 @@ def compute_occlusion_sensitivity( batch_images = [] batch_ids = [] + # Subtract from baseline + sensitivity_im = baseline - sensitivity_im + + # Reshape to match downsampled image + sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)) + + # If necessary, upsample + if np.any(stride_np != 1): + output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple + upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode) + sensitivity_im = upsampler(sensitivity_im.unsqueeze(0)) + # Convert tensor to numpy sensitivity_im = sensitivity_im.cpu().numpy() - # Reshape to size of output image - sensitivity_im = sensitivity_im.reshape(output_im_shape) - - # Squeeze, subtract from baseline and return - return baseline - np.squeeze(sensitivity_im) + # Squeeze and return + return np.squeeze(sensitivity_im) diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index d5c1cf20d2..7b26560d57 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -132,16 +132,15 @@ def compute_roc_auc( average = Average(average) if average == Average.MICRO: return _calculate(y.flatten(), y_pred.flatten()) - else: - y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] - if average == Average.NONE: - return auc_values - if average == Average.MACRO: - return np.mean(auc_values) - if average == Average.WEIGHTED: - weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) - raise ValueError( - f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].' - ) + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + if average == Average.NONE: + return auc_values + if average == Average.MACRO: + return np.mean(auc_values) + if average == Average.WEIGHTED: + weights = [sum(y_) for y_ in y] + return np.average(auc_values, weights=weights) + raise ValueError( + f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].' + ) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 7914364b9c..8dcbe4d9f6 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -9,49 +9,141 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Union import numpy as np import torch -from .utils import get_mask_edges, get_surface_distance +from monai.metrics.utils import * +from monai.utils import MetricReduction + + +class SurfaceDistanceMetric: + """ + Compute Surface Distance between two tensors. It can support both multi-classes and multi-labels tasks. + It supports both symmetric and asymmetric surface distance calculation. + Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. + You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. + + Args: + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``False``. + symmetric: whether to calculate the symmetric average surface distance between + `seg_pred` and `seg_gt`. Defaults to ``False``. + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + the metric used to compute surface distance. Defaults to ``"euclidean"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + + """ + + def __init__( + self, + include_background: bool = False, + symmetric: bool = False, + distance_metric: str = "euclidean", + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, + ) -> None: + super().__init__() + self.include_background = include_background + self.distance_metric = distance_metric + self.symmetric = symmetric + self.reduction = reduction + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + """ + Args: + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + + Raises: + ValueError: when `y` is not a binarized tensor. + ValueError: when `y_pred` has less than three dimensions. + """ + if not torch.all(y_pred.byte() == y_pred): + warnings.warn("y_pred is not a binarized tensor here!") + if not torch.all(y.byte() == y): + raise ValueError("y should be a binarized tensor.") + dims = y_pred.ndimension() + if dims < 3: + raise ValueError("y_pred should have at least three dimensions.") + # compute (BxC) for each channel for each batch + f = compute_average_surface_distance( + y_pred=y_pred, + y=y, + include_background=self.include_background, + symmetric=self.symmetric, + distance_metric=self.distance_metric, + ) + + # do metric reduction + f, not_nans = do_metric_reduction(f, self.reduction) + return f, not_nans def compute_average_surface_distance( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, + y_pred: Union[np.ndarray, torch.Tensor], + y: Union[np.ndarray, torch.Tensor], + include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", ): """ - This function is used to compute the Average Surface Distance from `seg_pred` to `seg_gt` + This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting. In addition, if sets ``symmetric = True``, the average symmetric surface distance between these two inputs will be returned. Args: - seg_pred: first binary or labelfield image. - seg_gt: second binary or labelfield image. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. - symmetric: if calculate the symmetric average surface distance between + y_pred: input data to compute, typical segmentation model output. + It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values + should be binarized. + y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. + The values should be binarized. + include_background: whether to skip distance computation on the first channel of + the predicted output. Defaults to ``False``. + symmetric: whether to calculate the symmetric average surface distance between `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. """ - (edges_pred, edges_gt) = get_mask_edges(seg_pred, seg_gt, label_idx) - surface_distance = get_surface_distance(edges_pred, edges_gt, label_idx, distance_metric=distance_metric) - if surface_distance.shape == (0,): - return np.inf - avg_surface_distance = surface_distance.mean() - if not symmetric: - return avg_surface_distance + if not include_background: + y_pred, y = ignore_background( + y_pred=y_pred, + y=y, + ) + + y = y.float() + y_pred = y_pred.float() + + if y.shape != y_pred.shape: + raise ValueError("y_pred and y should have same shapes.") + + batch_size, n_class = y_pred.shape[:2] + asd = np.empty((batch_size, n_class)) - surface_distance_2 = get_surface_distance(edges_gt, edges_pred, label_idx, distance_metric=distance_metric) - if surface_distance_2.shape == (0,): - return np.inf + for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) + if surface_distance.shape == (0,): + avg_surface_distance = np.nan + else: + avg_surface_distance = surface_distance.mean() + if not symmetric: + asd[b, c] = avg_surface_distance + else: + surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) + if surface_distance_2.shape == (0,): + avg_surface_distance_2 = np.nan + else: + avg_surface_distance_2 = surface_distance_2.mean() + asd[b, c] = np.mean((avg_surface_distance, avg_surface_distance_2)) - avg_surface_distance_2 = surface_distance_2.mean() - return np.mean((avg_surface_distance, avg_surface_distance_2)) + return torch.from_numpy(asd) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 08450fa355..ffe6093621 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -22,6 +22,8 @@ distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance"] + def ignore_background( y_pred: torch.Tensor, @@ -65,7 +67,7 @@ def do_metric_reduction( not_nans = (~nans).float() f[nans] = 0 - t_zero = torch.zeros(1, device=f.device, dtype=torch.float) + t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) reduction = MetricReduction(reduction) if reduction == MetricReduction.MEAN: @@ -91,9 +93,7 @@ def do_metric_reduction( elif reduction == MetricReduction.SUM_CHANNEL: not_nans = not_nans.sum(dim=1) f = f.sum(dim=1) # the channel sum - elif reduction == MetricReduction.NONE: - pass - else: + elif reduction != MetricReduction.NONE: raise ValueError( f"Unsupported reduction: {reduction}, available options are " '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel" "none"].' @@ -104,7 +104,7 @@ def do_metric_reduction( def get_mask_edges( seg_pred: Union[np.ndarray, torch.Tensor], seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int, + label_idx: int = 1, crop: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: """ @@ -141,9 +141,8 @@ def get_mask_edges( if torch.is_tensor(seg_gt): seg_gt = seg_gt.detach().cpu().numpy() - # Check non-zero number of elements and same shape - if seg_pred.size == 0 or seg_pred.shape != seg_gt.shape: - raise ValueError("Labelfields should have same shape (and non-zero number of elements)") + if seg_pred.shape != seg_gt.shape: + raise ValueError("seg_pred and seg_gt should have same shapes.") # If not binary images, convert them if seg_pred.dtype != bool: @@ -168,28 +167,16 @@ def get_mask_edges( def get_surface_distance( - edges_pred: np.ndarray, - edges_gt: np.ndarray, - label_idx: int, - crop: bool = True, + seg_pred: np.ndarray, + seg_gt: np.ndarray, distance_metric: str = "euclidean", ) -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. - In order to improve the computing efficiency, before getting the edges, - the images can be cropped and only keep the foreground if not specifies - ``crop = False``. - Args: - edges_pred: the edge of the predictions. - edges_gt: the edge of the ground truth. - label_idx: for labelfield images, convert to binary with - `seg_pred = seg_pred == label_idx`. - crop: crop input images and only keep the foregrounds. In order to - maintain two inputs' shapes, here the bounding box is achieved - by ``(seg_pred | seg_gt)`` which represents the union set of two - images. Defaults to ``True``. + seg_pred: the edge of the predictions. + seg_gt: the edge of the ground truth. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. @@ -198,17 +185,17 @@ def get_surface_distance( - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform. """ - if not np.any(edges_pred): - return np.array([]) - - if not np.any(edges_gt): - dis = np.inf * np.ones_like(edges_gt) + if not np.any(seg_gt): + dis = np.inf * np.ones_like(seg_gt) else: + if not np.any(seg_pred): + dis = np.inf * np.ones_like(seg_gt) + return dis[seg_gt] if distance_metric == "euclidean": - dis = distance_transform_edt(~edges_gt) - elif distance_metric == "chessboard" or distance_metric == "taxicab": - dis = distance_transform_cdt(~edges_gt, metric=distance_metric) + dis = distance_transform_edt(~seg_gt) + elif distance_metric in ["chessboard", "taxicab"]: + dis = distance_transform_cdt(~seg_gt, metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") - surface_distance = dis[edges_pred] - return surface_distance + + return dis[seg_pred] diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py index 585726edf2..ab399d4957 100644 --- a/monai/networks/blocks/acti_norm.py +++ b/monai/networks/blocks/acti_norm.py @@ -80,7 +80,7 @@ def __init__( super().__init__() op_dict = {"A": None, "D": None, "N": None} - # define the normalisation type and the arguments to the constructor + # define the normalization type and the arguments to the constructor if norm is not None: if norm_dim is None and dropout_dim is None: raise ValueError("norm_dim or dropout_dim needs to be specified.") diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 9125dc38cf..f400eaf3a3 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,5 +11,6 @@ from .convutils import * from .factories import * +from .filtering import * from .simplelayers import * from .spatial_transforms import * diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 1bb33ed9d7..41b63c55fb 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -16,7 +16,7 @@ is typically a type but can be any callable producing a layer object. The factory objects contain functions keyed to names converted to upper case, these names can be referred to as members -of the factory so that they can function as constant identifiers. eg. instance normalisation is named `Norm.INSTANCE`. +of the factory so that they can function as constant identifiers. eg. instance normalization is named `Norm.INSTANCE`. For example, to get a transpose convolution layer the name is needed and then a dimension argument is provided which is passed to the factory function: diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py new file mode 100644 index 0000000000..dcb172d892 --- /dev/null +++ b/monai/networks/layers/filtering.py @@ -0,0 +1,58 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.utils.module import optional_import + +_C, _ = optional_import("monai._C") + +__all__ = ["BilateralFilter"] + + +class BilateralFilter(torch.autograd.Function): + """ + Blurs the input tensor spatially whilst preserving edges. Can run on 1D, 2D, or 3D, + tensors (on top of Batch and Channel dimensions). Two implementations are provided, + an exact solution and a much faster approximation which uses a permutohedral lattice. + + See: + https://en.wikipedia.org/wiki/Bilateral_filter + https://graphics.stanford.edu/papers/permutohedral/ + + Args: + input: input tensor. + + spatial sigma: the standard deviation of the spatial blur. Higher values can + hurt performace when not using the approximate method (see fast approx). + + color sigma: the standard deviation of the color blur. Lower values preserve + edges better whilst higher values tend to a simple gaussian spatial blur. + + fast approx: This flag chooses between two implementations. The approximate method may + produce artifacts in some scenarios whereas the exact solution may be intolerably + slow for high spatial standard deviations. + + Returns: + output (torch.Tensor): output tensor. + """ + + @staticmethod + def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): + ctx.save_for_backward(spatial_sigma, color_sigma, fast_approx) + output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx) + return output_data + + @staticmethod + def backward(ctx, grad_output): + spatial_sigma, color_sigma, fast_approx = ctx.saved_variables + grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx) + return grad_input diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index a726975138..48012dfb1c 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -18,11 +18,82 @@ from torch.autograd import Function from monai.networks.layers.convutils import gaussian_1d, same_padding -from monai.utils import SkipMode, ensure_tuple_rep, optional_import +from monai.networks.layers.factories import Conv +from monai.utils import ( + PT_BEFORE_1_7, + ChannelMatching, + InvalidPyTorchVersionError, + SkipMode, + ensure_tuple_rep, + optional_import, +) _C, _ = optional_import("monai._C") +if not PT_BEFORE_1_7: + fft, _ = optional_import("torch.fft") + +__all__ = [ + "SkipConnection", + "Flatten", + "GaussianFilter", + "LLTM", + "Reshape", + "separable_filtering", + "HilbertTransform", + "ChannelPad", +] + + +class ChannelPad(nn.Module): + """ + Expand the input tensor's channel dimension from length `in_channels` to `out_channels`, + by padding or a projection. + """ -__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering"] + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: Union[ChannelMatching, str] = ChannelMatching.PAD, + ): + """ + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of input channels. + out_channels: number of output channels. + mode: {``"pad"``, ``"project"``} + Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. + + - ``"pad"``: with zero padding. + - ``"project"``: with a trainable conv with kernel size one. + """ + super().__init__() + self.project = None + self.pad = None + if in_channels == out_channels: + return + mode = ChannelMatching(mode) + if mode == ChannelMatching.PROJECT: + conv_type = Conv[Conv.CONV, spatial_dims] + self.project = conv_type(in_channels, out_channels, kernel_size=1) + return + if mode == ChannelMatching.PAD: + if in_channels > out_channels: + raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.') + pad_1 = (out_channels - in_channels) // 2 + pad_2 = out_channels - in_channels - pad_1 + pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0] + self.pad = tuple(pad) + return + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.project is not None: + return torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug + if self.pad is not None: + return F.pad(x, self.pad) + return x class SkipConnection(nn.Module): @@ -130,6 +201,72 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: return _conv(x, spatial_dims - 1) +class HilbertTransform(nn.Module): + """ + Determine the analytical signal of a Tensor along a particular axis. + Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). + + Args: + axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension). + N: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``. + """ + + def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None: + + if PT_BEFORE_1_7: + raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) + + super().__init__() + self.axis = axis + self.n = n + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``. + Returns: + torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using + FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``. + """ + + # Make input a real tensor + x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) + if torch.is_complex(x): + raise ValueError("x must be real.") + else: + x = x.to(dtype=torch.float) + + if (self.axis < 0) or (self.axis > len(x.shape) - 1): + raise ValueError("Invalid axis for shape of x.") + + n = x.shape[self.axis] if self.n is None else self.n + if n <= 0: + raise ValueError("N must be positive.") + x = torch.as_tensor(x, dtype=torch.complex64) + # Create frequency axis + f = torch.cat( + [ + torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)), + torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)), + ] + ) + xf = fft.fft(x, n=n, dim=self.axis) + # Create step function + u = torch.heaviside(f, torch.tensor([0.5], device=f.device)) + u = torch.as_tensor(u, dtype=x.dtype, device=u.device) + new_dims_before = self.axis + new_dims_after = len(xf.shape) - self.axis - 1 + for _ in range(new_dims_before): + u.unsqueeze_(0) + for _ in range(new_dims_after): + u.unsqueeze_(-1) + + ht = fft.ifft(xf * 2 * u, dim=self.axis) + + # Apply transform + return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype) + + class GaussianFilter(nn.Module): def __init__( self, diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index a64b6d2d0a..a6b730278d 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -518,7 +518,7 @@ def forward( if spatial_size is not None: dst_size = src_size[:2] + ensure_tuple(spatial_size) - # reverse and normalise theta if needed + # reverse and normalize theta if needed if not self.normalized: theta = to_norm_affine( affine=theta, src_size=src_size[2:], dst_size=dst_size[2:], align_corners=self.align_corners diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index c2239450f2..8d0aadafd6 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -53,6 +53,10 @@ def __init__( self.inter_channels = inter_channels if inter_channels is not None else list() self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels)) + # The number of channels and strides should match + if len(channels) != len(strides): + raise ValueError("Autoencoder expects matching number of channels and strides") + self.encoded_channels = in_channels decode_channel_list = list(channels[-2::-1]) + [out_channels] diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index a70da683ba..0915785db6 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -19,6 +19,37 @@ __all__ = ["DynUNet", "DynUnet", "Dynunet"] +class DynUNetSkipLayer(nn.Module): + """ + Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection. + The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet + structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on + looping over lists of layers and accumulating lists of output tensors which much be indexed. The `heads` list is + shared amongst all the instances of this class and is used to store the output from the supervision heads during + forward passes of the network. + """ + + heads: List[torch.Tensor] + + def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + super().__init__() + self.downsample = downsample + self.upsample = upsample + self.next_layer = next_layer + self.super_head = super_head + self.heads = heads + self.index = index + + def forward(self, x): + downout = self.downsample(x) + nextout = self.next_layer(downout) + upout = self.upsample(nextout, downout) + + self.heads[self.index] = self.super_head(upout) + + return upout + + class DynUNet(nn.Module): """ This reimplementation of a dynamic UNet (DynUNet) is based on: @@ -93,6 +124,43 @@ def __init__( self.check_kernel_stride() self.check_deep_supr_num() + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) + + def create_skips(index, downsamples, upsamples, superheads, bottleneck): + """ + Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is + done recursively from the top down since a recursive nn.Module subclass is being used to be compatible + with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` + since the `input_block` is passed to this function as the first item in `downsamples`, however this + shouldn't be associated with a supervision head. + """ + + assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" + assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}" + + if len(downsamples) == 0: # bottom of the network, pass the bottleneck block + return bottleneck + elif index == 0: # don't associate a supervision head with self.input_block + current_head, rest_heads = nn.Identity(), superheads + elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one + current_head, rest_heads = nn.Identity(), superheads[1:] + else: + current_head, rest_heads = superheads[0], superheads[1:] + + # create the next layer down, this will stop at the bottleneck layer + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) + + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) + + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.deep_supervision_heads, + self.bottleneck, + ) + def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." @@ -114,29 +182,13 @@ def check_deep_supr_num(self): assert 1 <= deep_supr_num < num_up_layers, error_msg def forward(self, x): - out = self.input_block(x) - outputs = [out] - - for downsample in self.downsamples: - out = downsample(out) - outputs.insert(0, out) - - out = self.bottleneck(out) - upsample_outs = [] - - for upsample, skip in zip(self.upsamples, outputs): - out = upsample(out, skip) - upsample_outs.append(out) - + out = self.skip_layers(x) out = self.output_block(out) if self.training and self.deep_supervision: - start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num - upsample_outs = upsample_outs[start_output_idx:-1][::-1] - preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] - return [out] + preds + return [out] + self.heads[1 : self.deep_supr_num + 1] - return out + return [out] def get_input_block(self): return self.conv_block( diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index c2adfd237a..918b5b5349 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -9,21 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F -from monai.networks.layers.convutils import same_padding -from monai.networks.layers.factories import Conv, Dropout, Norm -from monai.utils import Activation, ChannelMatching, Normalisation +from monai.networks.blocks import ADN, Convolution +from monai.networks.layers.simplelayers import ChannelPad +from monai.utils import ChannelMatching -SUPPORTED_NORM = { - Normalisation.BATCH: lambda spatial_dims: Norm[Norm.BATCH, spatial_dims], - Normalisation.INSTANCE: lambda spatial_dims: Norm[Norm.INSTANCE, spatial_dims], -} -SUPPORTED_ACTI = {Activation.RELU: nn.ReLU, Activation.PRELU: nn.PReLU, Activation.RELU6: nn.ReLU6} DEFAULT_LAYER_PARAMS_3D = ( # initial conv layer {"name": "conv_0", "n_features": 16, "kernel_size": 3}, @@ -37,64 +31,6 @@ ) -class ConvNormActi(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: int, - norm_type: Optional[Union[Normalisation, str]] = None, - acti_type: Optional[Union[Activation, str]] = None, - dropout_prob: Optional[float] = None, - ) -> None: - """ - Args: - spatial_dims: number of spatial dimensions of the input image. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: size of the convolving kernel. - norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"batch"``. - acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} - Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. - dropout_prob: probability of the feature map to be zeroed - (only applies to the penultimate conv layer). - """ - - super(ConvNormActi, self).__init__() - - layers = nn.ModuleList() - - conv_type = Conv[Conv.CONV, spatial_dims] - padding_size = same_padding(kernel_size) - conv = conv_type(in_channels, out_channels, kernel_size, padding=padding_size) - layers.append(conv) - - if norm_type is not None: - norm_type = Normalisation(norm_type) - layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(out_channels)) - if acti_type is not None: - acti_type = Activation(acti_type) - layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) - if dropout_prob is not None: - dropout_type = Dropout[Dropout.DROPOUT, spatial_dims] - layers.append(dropout_type(p=dropout_prob)) - self.layers = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.as_tensor(self.layers(x)) - - -class ChannelPad(nn.Module): - def __init__(self, pad): - super().__init__() - self.pad = tuple(pad) - - def forward(self, x): - return F.pad(x, self.pad) - - class HighResBlock(nn.Module): def __init__( self, @@ -103,8 +39,8 @@ def __init__( out_channels: int, kernels: Sequence[int] = (3, 3), dilation: Union[Sequence[int], int] = 1, - norm_type: Union[Normalisation, str] = Normalisation.INSTANCE, - acti_type: Union[Activation, str] = Activation.RELU, + norm_type: Union[Tuple, str] = ("batch", {"affine": True}), + acti_type: Union[Tuple, str] = ("relu", {"inplace": True}), channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: """ @@ -114,51 +50,39 @@ def __init__( out_channels: number of output channels. kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k. dilation: spacing between kernel elements. - norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"instance"``. + norm_type: feature normalization type and arguments. + Defaults to ``("batch", {"affine": True})``. acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. channel_matching: {``"pad"``, ``"project"``} Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. - ``"pad"``: with zero padding. - - ``"project"``: with a trainable conv with kernel size. + - ``"project"``: with a trainable conv with kernel size one. Raises: ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values. """ super(HighResBlock, self).__init__() - conv_type = Conv[Conv.CONV, spatial_dims] - norm_type = Normalisation(norm_type) - acti_type = Activation(acti_type) - - self.project = None - self.pad = None - - if in_channels != out_channels: - channel_matching = ChannelMatching(channel_matching) - - if channel_matching == ChannelMatching.PROJECT: - self.project = conv_type(in_channels, out_channels, kernel_size=1) - - if channel_matching == ChannelMatching.PAD: - if in_channels > out_channels: - raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.') - pad_1 = (out_channels - in_channels) // 2 - pad_2 = out_channels - in_channels - pad_1 - pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0] - self.pad = ChannelPad(pad) + self.chn_pad = ChannelPad( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching + ) layers = nn.ModuleList() _in_chns, _out_chns = in_channels, out_channels for kernel_size in kernels: - layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns)) - layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) layers.append( - conv_type( - _in_chns, _out_chns, kernel_size, padding=same_padding(kernel_size, dilation), dilation=dilation + ADN(ordering="NA", in_channels=_in_chns, act=acti_type, norm=norm_type, norm_dim=spatial_dims) + ) + layers.append( + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, + kernel_size=kernel_size, + dilation=dilation, ) ) _in_chns = _out_chns @@ -167,14 +91,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x_conv: torch.Tensor = self.layers(x) - - if self.project is not None: - return x_conv + torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug - - if self.pad is not None: - return x_conv + torch.as_tensor(self.pad(x)) - - return x_conv + x + return x_conv + torch.as_tensor(self.chn_pad(x)) class HighResNet(nn.Module): @@ -191,13 +108,18 @@ class HighResNet(nn.Module): spatial_dims: number of spatial dimensions of the input image. in_channels: number of input channels. out_channels: number of output channels. - norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"batch"``. - acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} - Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. + norm_type: feature normalization type and arguments. + Defaults to ``("batch", {"affine": True})``. + acti_type: activation type and arguments. + Defaults to ``("relu", {"inplace": True})``. dropout_prob: probability of the feature map to be zeroed (only applies to the penultimate conv layer). layer_params: specifying key parameters of each layer/block. + channel_matching: {``"pad"``, ``"project"``} + Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. + + - ``"pad"``: with zero padding. + - ``"project"``: with a trainable conv with kernel size one. """ def __init__( @@ -205,10 +127,11 @@ def __init__( spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, - norm_type: Union[Normalisation, str] = Normalisation.BATCH, - acti_type: Union[Activation, str] = Activation.RELU, - dropout_prob: Optional[float] = None, + norm_type: Union[str, tuple] = ("batch", {"affine": True}), + acti_type: Union[str, tuple] = ("relu", {"inplace": True}), + dropout_prob: Optional[Union[Tuple, str, float]] = 0.0, layer_params: Sequence[Dict] = DEFAULT_LAYER_PARAMS_3D, + channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: super(HighResNet, self).__init__() @@ -218,14 +141,14 @@ def __init__( params = layer_params[0] _in_chns, _out_chns = in_channels, params["n_features"] blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - _out_chns, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=acti_type, - dropout_prob=None, + adn_ordering="NA", + act=acti_type, + norm=norm_type, ) ) @@ -236,13 +159,14 @@ def __init__( for _ in range(params["repeat"]): blocks.append( HighResBlock( - spatial_dims, - _in_chns, - _out_chns, - params["kernels"], + spatial_dims=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, + kernels=params["kernels"], dilation=_dilation, norm_type=norm_type, acti_type=acti_type, + channel_matching=channel_matching, ) ) _in_chns = _out_chns @@ -251,28 +175,30 @@ def __init__( params = layer_params[-2] _in_chns, _out_chns = _out_chns, params["n_features"] blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - _out_chns, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=acti_type, - dropout_prob=dropout_prob, + adn_ordering="NAD", + act=acti_type, + norm=norm_type, + dropout=dropout_prob, ) ) params = layer_params[-1] _in_chns = _out_chns blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - out_channels, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=out_channels, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=None, - dropout_prob=None, + adn_ordering="NAD", + act=acti_type, + norm=norm_type, + dropout=dropout_prob, ) ) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a46e8e66d7..1bcccd084c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -20,6 +20,17 @@ from monai.utils import ensure_tuple_size +__all__ = [ + "one_hot", + "slice_channels", + "predict_segmentation", + "normalize_transform", + "to_norm_affine", + "normal_init", + "icnr_init", + "pixelshuffle", +] + def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: """ @@ -72,11 +83,10 @@ def predict_segmentation( """ if not mutually_exclusive: return (cast(torch.Tensor, logits >= threshold)).int() - else: - if logits.shape[1] == 1: - warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.") - return (cast(torch.Tensor, logits >= threshold)).int() - return logits.argmax(1, keepdim=True) + if logits.shape[1] == 1: + warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.") + return (cast(torch.Tensor, logits >= threshold)).int() + return logits.argmax(1, keepdim=True) def normalize_transform( @@ -145,8 +155,7 @@ def to_norm_affine( src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners) dst_xform = normalize_transform(dst_size, affine.device, affine.dtype, align_corners) - new_affine = src_xform @ affine @ torch.inverse(dst_xform) - return new_affine + return src_xform @ affine @ torch.inverse(dst_xform) def normal_init( diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index fdd8bc072f..4cafa45749 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -13,7 +13,9 @@ import torch -from monai.utils import ensure_tuple +from monai.utils import ensure_tuple, ensure_tuple_rep + +__all__ = ["generate_param_groups"] def generate_param_groups( @@ -35,9 +37,11 @@ def generate_param_groups( match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions, can be "select" or "filter". lr_values: a list of LR values corresponding to the `layer_matches` functions. - include_others: whether to incude the rest layers as the last group, default to True. + include_others: whether to include the rest layers as the last group, default to True. + + It's mainly used to set different LR values for different network elements, for example: - It's mainly used to set different init LR values for different network elements, for example:: + .. code-block:: python net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) print(net) # print out network components to select expected items @@ -48,27 +52,41 @@ def generate_param_groups( match_types=["select", "filter"], lr_values=[1e-2, 1e-3], ) + # the groups will be a list of dictionaries: + # [{'params': , 'lr': 0.01}, + # {'params': , 'lr': 0.001}, + # {'params': }] optimizer = torch.optim.Adam(params, 1e-4) """ layer_matches = ensure_tuple(layer_matches) - match_types = ensure_tuple(match_types) - lr_values = ensure_tuple(lr_values) - if len(layer_matches) != len(lr_values) or len(layer_matches) != len(match_types): - raise ValueError("length of layer_match callable functions, match types and LR values should be the same.") + match_types = ensure_tuple_rep(match_types, len(layer_matches)) + lr_values = ensure_tuple_rep(lr_values, len(layer_matches)) + + def _get_select(f): + def _select(): + return f(network).parameters() + + return _select + + def _get_filter(f): + def _filter(): + return filter(f, network.named_parameters()) + + return _filter params = list() _layers = list() for func, ty, lr in zip(layer_matches, match_types, lr_values): - if ty == "select": - layer_params = func(network).parameters() - elif ty == "filter": - layer_params = filter(func, network.named_parameters()) + if ty.lower() == "select": + layer_params = _get_select(func) + elif ty.lower() == "filter": + layer_params = _get_filter(func) else: - raise ValueError(f"unsuppoted layer match type: {ty}.") + raise ValueError(f"unsupported layer match type: {ty}.") - params.append({"params": layer_params, "lr": lr}) - _layers.extend(list(map(id, layer_params))) + params.append({"params": layer_params(), "lr": lr}) + _layers.extend(list(map(id, layer_params()))) if include_others: params.append({"params": filter(lambda p: id(p) not in _layers, network.parameters())}) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index d5cae18a53..13d2e640bc 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -22,6 +22,8 @@ from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed +__all__ = ["Transform", "Randomizable", "Compose", "MapTransform"] + class Transform(ABC): """ @@ -192,7 +194,7 @@ class Compose(Randomizable, Transform): set of functions must be called as if it were a sequence. Example: images and labels - Images typically require some kind of normalisation that labels do not. + Images typically require some kind of normalization that labels do not. Both are then typically augmented through the use of random rotations, flips, and deformations. Compose can be used with a series of transforms that take a dictionary diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d8075e5d01..4c69a61b15 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -632,3 +632,43 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ return self.padder(self.cropper(img), mode=mode) + + +class BoundingRect(Transform): + """ + Compute coordinates of axis-aligned bounding rectangles from input image `img`. + The output format of the coordinates is (shape is [channel, 2 * spatial dims]): + + [[1st_spatial_dim_start, 1st_spatial_dim_end, + 2nd_spatial_dim_start, 2nd_spatial_dim_end, + ..., + Nth_spatial_dim_start, Nth_spatial_dim_end], + + ... + + [1st_spatial_dim_start, 1st_spatial_dim_end, + 2nd_spatial_dim_start, 2nd_spatial_dim_end, + ..., + Nth_spatial_dim_start, Nth_spatial_dim_end]] + + The bounding boxes edges are aligned with the input image edges. + This function returns [-1, -1, ...] if there's no positive intensity. + + Args: + select_fn: function to select expected foreground, default is to select values > 0. + """ + + def __init__(self, select_fn: Callable = lambda x: x > 0) -> None: + self.select_fn = select_fn + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. + """ + bbox = list() + + for channel in range(img.shape[0]): + start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel) + bbox.append([i for k in zip(start_, end_) for i in k]) + + return np.stack(bbox, axis=0) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 7d0b4f85cd..8e927eb605 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -24,6 +24,7 @@ from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.croppad.array import ( BorderPad, + BoundingRect, CenterSpatialCrop, DivisiblePad, ResizeWithPadOrCrop, @@ -580,6 +581,37 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class BoundingRectd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.BoundingRect`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + bbox_key_postfix: the output bounding box coordinates will be + written to the value of `{key}_{bbox_key_postfix}`. + select_fn: function to select expected foreground, default is to select values > 0. + """ + + def __init__(self, keys: KeysCollection, bbox_key_postfix: str = "bbox", select_fn: Callable = lambda x: x > 0): + super().__init__(keys=keys) + self.bbox = BoundingRect(select_fn=select_fn) + self.bbox_key_postfix = bbox_key_postfix + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + """ + See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. + """ + d = dict(data) + for key in self.keys: + bbox = self.bbox(d[key]) + key_to_add = f"{key}_{self.bbox_key_postfix}" + if key_to_add in d: + raise KeyError(f"Bounding box data with key {key_to_add} already exists.") + d[key_to_add] = bbox + return d + + SpatialPadD = SpatialPadDict = SpatialPadd BorderPadD = BorderPadDict = BorderPadd DivisiblePadD = DivisiblePadDict = DivisiblePadd @@ -591,3 +623,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda RandWeightedCropD = RandWeightedCropDict = RandWeightedCropd RandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld ResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd +BoundingRectD = BoundingRectDict = BoundingRectd diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a464109417..ac2d1e46fd 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -20,10 +20,10 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter +from monai.networks.layers import GaussianFilter, HilbertTransform from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array -from monai.utils import dtype_torch_to_numpy, ensure_tuple_size +from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size class RandGaussianNoise(Randomizable, Transform): @@ -200,6 +200,7 @@ class NormalizeIntensity(Transform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -208,11 +209,13 @@ def __init__( divisor: Optional[Sequence] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: self.subtrahend = subtrahend self.divisor = divisor self.nonzero = nonzero self.channel_wise = channel_wise + self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) @@ -252,7 +255,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: else: img = self._normalize(img, self.subtrahend, self.divisor) - return img + return img.astype(self.dtype) class ThresholdIntensity(Transform): @@ -509,6 +512,46 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ +class DetectEnvelope(Transform): + """ + Find the envelope of the input data along the requested axis using a Hilbert transform. + Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). + + Args: + axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension. + N: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension + ``axis``. + + """ + + def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: + + if PT_BEFORE_1_7: + raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__) + + if axis < 0: + raise ValueError("axis must be zero or positive.") + + self.axis = axis + self.n = n + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + + Args: + img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + + Returns: + np.ndarray containing envelope of data in img along the specified axis. + + """ + # add one to transform axis because a batch axis will be added at dimension 0 + hilbert_transform = HilbertTransform(self.axis + 1, self.n) + # convert to Tensor and add Batch axis expected by HilbertTransform + input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) + return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + + class GaussianSmooth(Transform): """ Apply Gaussian smooth to the input data based on specified `sigma` parameter. diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index f0030849d9..64f641ecd1 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -227,6 +227,7 @@ class NormalizeIntensityd(MapTransform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. + dtype: output data type, defaut to float32. """ def __init__( @@ -236,9 +237,10 @@ def __init__( divisor: Optional[np.ndarray] = None, nonzero: bool = False, channel_wise: bool = False, + dtype: np.dtype = np.float32, ) -> None: super().__init__(keys) - self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise) + self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 22e2b8e3d6..fd44555fa7 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -70,13 +70,14 @@ def __init__( if reader is not None: if isinstance(reader, str): supported_readers = { - "NibabelReader": NibabelReader, - "PILReader": PILReader, - "ITKReader": ITKReader, - "NumpyReader": NumpyReader, + "nibabelreader": NibabelReader, + "pilreader": PILReader, + "itkreader": ITKReader, + "numpyreader": NumpyReader, } + reader = reader.lower() if reader not in supported_readers: - raise ValueError(f"unsupported reader type: {reader}.") + raise ValueError(f"unsupported reader type: {reader}, available options: {supported_readers}.") self.register(supported_readers[reader](*args, **kwargs)) else: self.register(reader) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f417fabffa..8daad86dd2 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -20,8 +20,8 @@ import numpy as np import torch -from monai.transforms.compose import Transform -from monai.transforms.utils import map_binary_to_indices +from monai.transforms.compose import Randomizable, Transform +from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple # Generic type which can represent either a numpy.ndarray or a torch.Tensor @@ -535,3 +535,65 @@ def __call__( bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) return fg_indices, bg_indices + + +class AddExtremePointsChannel(Transform, Randomizable): + """ + Add extreme points of label to the image as a new channel. This transform generates extreme + point from label and applies a gaussian filter. The pixel values in points image are rescaled + to range [rescale_min, rescale_max] and added as a new channel to input image. The algorithm is + described in Roth et al., Going to Extremes: Weakly Supervised Medical Image Segmentation + https://arxiv.org/abs/2009.11988. + + This transform only supports single channel labels (1, spatial_dim1, [spatial_dim2, ...]). The + background ``index`` is ignored when calculating extreme points. + + Args: + background: Class index of background label, defaults to 0. + pert: Random perturbation amount to add to the points, defaults to 0.0. + + Raises: + ValueError: When no label image provided. + ValueError: When label image is not single channel. + """ + + def __init__(self, background: int = 0, pert: float = 0.0) -> None: + self._background = background + self._pert = pert + self._points: List[Tuple[int, ...]] = [] + + def randomize(self, label: np.ndarray) -> None: + self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) + + def __call__( + self, + img: np.ndarray, + label: Optional[np.ndarray] = None, + sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, + rescale_min: float = -1.0, + rescale_max: float = 1.0, + ) -> np.ndarray: + """ + Args: + img: the image that we want to add new channel to. + label: label image to get extreme points from. Shape must be + (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels. + sigma: if a list of values, must match the count of spatial dimensions of input data, + and apply every value in the list to 1 spatial dimension. if only 1 value provided, + use it for all spatial dimensions. + rescale_min: minimum value of output data. + rescale_max: maximum value of output data. + """ + if label is None: + raise ValueError("This transform requires a label array!") + if label.shape[0] != 1: + raise ValueError("Only supports single channel labels!") + + # Generate extreme points + self.randomize(label[0, :]) + + points_image = extreme_points_to_image( + points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max + ) + + return np.concatenate([img, points_image], axis=0) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e6a9da8076..28d7452e77 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -17,13 +17,14 @@ import copy import logging -from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import KeysCollection -from monai.transforms.compose import MapTransform +from monai.transforms import extreme_points_to_image, get_extreme_points +from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, @@ -158,7 +159,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class SplitChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. - All the input specified by `keys` should be splitted into same count of data. + All the input specified by `keys` should be split into same count of data. """ @@ -636,6 +637,93 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): + """ + Convert labels to multi channels based on brats18 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the the peritumoral edema + label 4 is the GD-enhancing tumor + The possible classes are TC (Tumor core), WT (Whole tumor) + and ET (Enhancing tumor). + """ + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + result = list() + # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC + result.append(np.logical_or(d[key] == 1, d[key] == 4)) + # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT + result.append(np.logical_or(np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2)) + # label 4 is ET + result.append(d[key] == 4) + d[key] = np.stack(result, axis=0).astype(np.float32) + return d + + +class AddExtremePointsChanneld(Randomizable, MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + label_key: key to label source to get the extreme points. + background: Class index of background label, defaults to 0. + pert: Random perturbation amount to add to the points, defaults to 0.0. + sigma: if a list of values, must match the count of spatial dimensions of input data, + and apply every value in the list to 1 spatial dimension. if only 1 value provided, + use it for all spatial dimensions. + rescale_min: minimum value of output data. + rescale_max: maximum value of output data. + + """ + + def __init__( + self, + keys: KeysCollection, + label_key: str, + background: int = 0, + pert: float = 0.0, + sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, + rescale_min: float = -1.0, + rescale_max: float = 1.0, + ): + super().__init__(keys) + self.background = background + self.pert = pert + self.points: List[Tuple[int, ...]] = [] + self.label_key = label_key + self.sigma = sigma + self.rescale_min = rescale_min + self.rescale_max = rescale_max + + def randomize(self, label: np.ndarray) -> None: + self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) + + def __call__(self, data): + d = dict(data) + label = d[self.label_key] + if label.shape[0] != 1: + raise ValueError("Only supports single channel labels!") + + # Generate extreme points + self.randomize(label[0, :]) + + for key in data.keys(): + if key in self.keys: + img = d[key] + points_image = extreme_points_to_image( + points=self.points, + label=label, + sigma=self.sigma, + rescale_min=self.rescale_min, + rescale_max=self.rescale_max, + ) + d[key] = np.concatenate([img, points_image], axis=0) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -653,3 +741,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda LambdaD = LambdaDict = Lambdad LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd +ConvertToMultiChannelBasedOnBratsClassesD = ( + ConvertToMultiChannelBasedOnBratsClassesDict +) = ConvertToMultiChannelBasedOnBratsClassesd +AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 15d31016f2..3b552f543c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -17,6 +18,7 @@ import torch from monai.config import IndexSelection +from monai.networks.layers import GaussianFilter from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -59,10 +61,7 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: if np.any(img[:, :, :margin]) or np.any(img[:, :, -margin:]): return False - if np.any(img[:, :margin, :]) or np.any(img[:, -margin:, :]): - return False - - return True + return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) def rescale_array( @@ -261,8 +260,7 @@ def weighted_patch_samples( idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") # compensate 'valid' mode diff = np.minimum(win_size, img_size) // 2 - centers = [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=np.int)] - return centers + return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=np.int)] def generate_pos_neg_label_crop_centers( @@ -426,7 +424,7 @@ def create_rotate(spatial_dims: int, radians: Union[Sequence[float], float]) -> return np.array([[cos_, -sin_, 0.0], [sin_, cos_, 0.0], [0.0, 0.0, 1.0]]) raise ValueError("radians must be non empty.") - if spatial_dims == 3: + elif spatial_dims == 3: affine = None if len(radians) >= 1: sin_, cos_ = np.sin(radians[0]), np.cos(radians[0]) @@ -465,7 +463,7 @@ def create_shear(spatial_dims: int, coefs: Union[Sequence[float], float]) -> np. if spatial_dims == 2: coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0) return np.array([[1, coefs[0], 0.0], [coefs[1], 1.0, 0.0], [0.0, 0.0, 1.0]]) - if spatial_dims == 3: + elif spatial_dims == 3: coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0) return np.array( [ @@ -515,6 +513,13 @@ def generate_spatial_bounding_box( generate the spatial bounding box of foreground in the image with start-end positions. Users can define arbitrary function to select expected foreground from the whole image or specified channels. And it can also add margin to every dim of the bounding box. + The output format of the coordinates is: + + [1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start], + [1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end] + + The bounding boxes edges are aligned with the input image edges. + This function returns [-1, -1, ...], [-1, -1, ...] if there's no positive intensity. Args: img: source image to generate bounding box from. @@ -523,17 +528,26 @@ def generate_spatial_bounding_box( of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. """ - data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img + data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img data = np.any(select_fn(data), axis=0) - nonzero_idx = np.nonzero(data) - margin = ensure_tuple_rep(margin, data.ndim) - - box_start = list() - box_end = list() - for i in range(data.ndim): - assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}" - box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i])) - box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1)) + ndim = len(data.shape) + margin = ensure_tuple_rep(margin, ndim) + for m in margin: + if m < 0: + raise ValueError("margin value should not be negative number.") + + box_start = [0] * ndim + box_end = [0] * ndim + + for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): + dt = data.any(axis=ax) + if not np.any(dt): + return [-1] * ndim, [-1] * ndim + + min_d = max(np.argmax(dt) - margin[di], 0) + max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) + box_start[di], box_end[di] = min_d, max_d + return box_start, box_end @@ -554,3 +568,97 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option if item.max() != 0: largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1) return torch.as_tensor(largest_cc, device=img.device) + + +def get_extreme_points( + img: np.ndarray, rand_state: np.random.RandomState = np.random, background: int = 0, pert: float = 0.0 +) -> List[Tuple[int, ...]]: + """ + Generate extreme points from an image. These are used to generate initial segmentation + for annotation models. An optional perturbation can be passed to simulate user clicks. + + Args: + img: + Image to generate extreme points from. Expected Shape is ``(spatial_dim1, [, spatial_dim2, ...])``. + rand_state: `np.random.RandomState` object used to select random indices. + background: Value to be consider as background, defaults to 0. + pert: Random perturbation amount to add to the points, defaults to 0.0. + + Returns: + A list of extreme points, its length is equal to 2 * spatial dimension of input image. + The output format of the coordinates is: + + [1st_spatial_dim_min, 1st_spatial_dim_max, 2nd_spatial_dim_min, ..., Nth_spatial_dim_max] + + Raises: + ValueError: When the input image does not have any foreground pixel. + """ + indices = np.where(img != background) + if np.size(indices[0]) == 0: + raise ValueError("get_extreme_points: no foreground object in mask!") + + def _get_point(val, dim): + """ + Select one of the indices within slice containing val. + + Args: + val : value for comparison + dim : dimension in which to look for value + """ + idx = rand_state.choice(np.where(indices[dim] == val)[0]) + pt = [] + for j in range(img.ndim): + # add +- pert to each dimension + val = int(indices[j][idx] + 2.0 * pert * (rand_state.rand() - 0.5)) + val = max(val, 0) + val = min(val, img.shape[j] - 1) + pt.append(val) + return pt + + points = [] + for i in range(img.ndim): + points.append(tuple(_get_point(np.min(indices[i][...]), i))) + points.append(tuple(_get_point(np.max(indices[i][...]), i))) + + return points + + +def extreme_points_to_image( + points: List[Tuple[int, ...]], + label: np.ndarray, + sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + rescale_min: float = -1.0, + rescale_max: float = 1.0, +): + """ + Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage. + + Applies a gaussian filter to the extreme points image. Then the pixel values in points image are rescaled + to range [rescale_min, rescale_max]. + + Args: + points: Extreme points of the object/organ. + label: label image to get extreme points from. Shape must be + (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels. + sigma: if a list of values, must match the count of spatial dimensions of input data, + and apply every value in the list to 1 spatial dimension. if only 1 value provided, + use it for all spatial dimensions. + rescale_min: minimum value of output data. + rescale_max: maximum value of output data. + """ + # points to image + points_image = torch.zeros(label.shape[1:], dtype=torch.float) + for p in points: + points_image[p] = 1.0 + + # add channel and add batch + points_image = points_image.unsqueeze(0).unsqueeze(0) + gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma) + points_image = gaussian_filter(points_image).squeeze(0).detach().numpy() + + # rescale the points image to [rescale_min, rescale_max] + min_intensity = np.min(points_image) + max_intensity = np.max(points_image) + points_image = (points_image - min_intensity) / (max_intensity - min_intensity) + points_image = points_image * (rescale_max - rescale_min) + rescale_min + return points_image diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 8f27a7faf3..d2d3e41d67 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -15,3 +15,4 @@ from .enums import * from .misc import * from .module import * +from .profiling import * diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py index 0d14eb0dd3..35a594d077 100644 --- a/monai/utils/decorators.py +++ b/monai/utils/decorators.py @@ -9,28 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from functools import wraps -def timing(func): - """ - This simple timing function decorator prints to stdout/logfile (it uses printFlush) how many seconds a call to the - original function took to execute, as well as the name before and after the call. - """ - - @wraps(func) - def timingwrap(*args, **kwargs): - print(func.__name__, flush=True) - start = time.perf_counter() - res = func(*args, **kwargs) - end = time.perf_counter() - print(func.__name__, "dT (s) =", (end - start), flush=True) - return res - - return timingwrap - - class RestartGenerator: """ Wraps a generator callable which will be called whenever this class is iterated and its result returned. This is diff --git a/monai/utils/enums.py b/monai/utils/enums.py index dbebbe364f..dfb51d18c5 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -144,7 +144,7 @@ class Weight(Enum): UNIFORM = "uniform" -class Normalisation(Enum): +class Normalization(Enum): """ See also: - :py:class:`monai.networks.nets.ConvNormActi` diff --git a/monai/utils/misc.py b/monai/utils/misc.py index ef688174f1..020884bbcc 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -12,7 +12,6 @@ import collections.abc import itertools import random -import time from ast import literal_eval from distutils.util import strtobool from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -287,23 +286,3 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" return _np_to_torch_dtype[dtype] - - -class PerfContext: - """ - Context manager for tracking how much time is spent within context blocks. This uses `time.perf_counter` to - accumulate the total amount of time in seconds in the attribute `total_time` over however many context blocks - the object is used in. - """ - - def __init__(self): - self.total_time = 0 - self.start_time = None - - def __enter__(self): - self.start_time = time.perf_counter() - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.total_time += time.perf_counter() - self.start_time - self.start_time = None diff --git a/monai/utils/module.py b/monai/utils/module.py index 0edf9047ac..dfd5fb7d7b 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -16,11 +16,14 @@ from re import match from typing import Any, Callable, List, Sequence, Tuple, Union +import torch + from .misc import ensure_tuple OPTIONAL_IMPORT_MSG_FMT = "{}" __all__ = [ + "InvalidPyTorchVersionError", "OptionalImportError", "exact_version", "export", @@ -30,6 +33,8 @@ "get_full_type_name", "has_option", "get_package_version", + "get_torch_version_tuple", + "PT_BEFORE_1_7", ] @@ -105,6 +110,17 @@ def exact_version(the_module, version_str: str = "") -> bool: return bool(the_module.__version__ == version_str) +class InvalidPyTorchVersionError(Exception): + """ + Raised when called function or method requires a more recent + PyTorch version than that installed. + """ + + def __init__(self, required_version, name): + message = f"{name} requires PyTorch version {required_version} or later" + super().__init__(message) + + class OptionalImportError(ImportError): """ Could not import APIs from an optional dependency. @@ -228,10 +244,7 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: if not callable(obj): return False sig = inspect.signature(obj) - for key in ensure_tuple(keywords): - if key not in sig.parameters: - return False - return True + return all(key in sig.parameters for key in ensure_tuple(keywords)) def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): @@ -252,3 +265,22 @@ def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."): del dep del sys.modules[dep_name] return dep_ver + + +def get_torch_version_tuple(): + """ + Returns: + tuple of ints represents the pytorch major/minor version. + """ + return tuple((int(x) for x in torch.__version__.split(".")[:2])) + + +PT_BEFORE_1_7 = True +ver, has_ver = optional_import("pkg_resources", name="parse_version") +try: + if has_ver: + PT_BEFORE_1_7 = ver(torch.__version__) < ver("1.7") + else: + PT_BEFORE_1_7 = get_torch_version_tuple() < (1, 7) +except (AttributeError, TypeError): + pass diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py new file mode 100644 index 0000000000..bcdc0357c4 --- /dev/null +++ b/monai/utils/profiling.py @@ -0,0 +1,110 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from functools import wraps + +import torch + +__all__ = ["torch_profiler_full", "torch_profiler_time_cpu_gpu", "torch_profiler_time_end_to_end", "PerfContext"] + + +def torch_profiler_full(func): + """ + A decorator which will run the torch profiler for the decorated function, + printing the results in full. + Note: Enforces a gpu sync point which could slow down pipelines. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + + with torch.autograd.profiler.profile(use_cuda=True) as prof: + result = func(*args, **kwargs) + + print(prof, flush=True) + + return result + + return wrapper + + +def torch_profiler_time_cpu_gpu(func): + """ + A decorator which measures the execution time of both the CPU and GPU components + of the decorated function, printing both results. + Note: Enforces a gpu sync point which could slow down pipelines. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + + with torch.autograd.profiler.profile(use_cuda=True) as prof: + result = func(*args, **kwargs) + + cpu_time = prof.self_cpu_time_total + gpu_time = sum(evt.self_cuda_time_total for evt in prof.function_events) + + cpu_time = torch.autograd.profiler.format_time(cpu_time) + gpu_time = torch.autograd.profiler.format_time(gpu_time) + + print("cpu time: {}, gpu time: {}".format(cpu_time, gpu_time), flush=True) + + return result + + return wrapper + + +def torch_profiler_time_end_to_end(func): + """ + A decorator which measures the total execution time from when the decorated + function is called to when the last cuda operation finishes, printing the result. + Note: Enforces a gpu sync point which could slow down pipelines. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + + torch.cuda.synchronize() + start = time.perf_counter() + + result = func(*args, **kwargs) + + torch.cuda.synchronize() + end = time.perf_counter() + + total_time = (end - start) * 1e6 + total_time_str = torch.autograd.profiler.format_time(total_time) + print("end to end time: {}".format(total_time_str), flush=True) + + return result + + return wrapper + + +class PerfContext: + """ + Context manager for tracking how much time is spent within context blocks. This uses `time.perf_counter` to + accumulate the total amount of time in seconds in the attribute `total_time` over however many context blocks + the object is used in. + """ + + def __init__(self): + self.total_time = 0 + self.start_time = None + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.total_time += time.perf_counter() - self.start_time + self.start_time = None diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 9b22cbdba1..c11bfcfc99 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -28,6 +28,9 @@ SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") +__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] + + def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: """Function to actually create the animated gif. @@ -76,10 +79,7 @@ def make_animated_gif_summary( if the image data is between 0 and 1, using 255 for this value will scale it to displayable range """ - if max_out == 1: - suffix = "/image" - else: - suffix = "/image/{}" + suffix = "/image" if max_out == 1 else "/image/{}" if other_indices is None: other_indices = {} axis_order = [0] + list(animation_axes) + list(image_axes) @@ -194,9 +194,9 @@ def plot_2d_or_3d_image( dataformats = "CHW" writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats) return + dataformats = "HW" for j, d2 in enumerate(d[:max_channels]): d2 = rescale_array(d2, 0, 1) - dataformats = "HW" writer.add_image(f"{tag}_{dataformats}_{j}", d2, step, dataformats=dataformats) return diff --git a/requirements-dev.txt b/requirements-dev.txt index fb0c24c859..3de7365d16 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -26,3 +26,7 @@ mypy>=0.790 ninja torchvision psutil +Sphinx==3.3.0 +recommonmark==0.6.0 +sphinx-autodoc-typehints==1.11.1 +sphinx-rtd-theme==0.5.0 diff --git a/setup.cfg b/setup.cfg index 2c11b789c7..78cf8db6ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,13 +65,13 @@ ignore = # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' N812 per-file-ignores = __init__.py: F401 -exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,_version.py +exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,.venv,_version.py [isort] known_first_party = monai profile = black line_length = 120 -skip = .git, .eggs, venv, versioneer.py, _version.py, conf.py, monai/__init__.py +skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py skip_glob = *.pyi [versioneer] diff --git a/tests/min_tests.py b/tests/min_tests.py index e22d94bc57..ccfc789992 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -32,6 +32,7 @@ def run_testsuit(): "test_cachedataset", "test_cachedataset_parallel", "test_dataset", + "test_detect_envelope", "test_iterable_dataset", "test_ensemble_evaluator", "test_handler_checkpoint_loader", @@ -40,17 +41,20 @@ def run_testsuit(): "test_handler_lr_scheduler", "test_handler_confusion_matrix", "test_handler_confusion_matrix_dist", + "test_handler_hausdorff_distance", "test_handler_mean_dice", "test_handler_rocauc", "test_handler_rocauc_dist", "test_handler_segmentation_saver", "test_handler_smartcache", "test_handler_stats", + "test_handler_surface_distance", "test_handler_tb_image", "test_handler_tb_stats", "test_handler_validation", "test_hausdorff_distance", "test_header_correct", + "test_hilbert_transform", "test_img2tensorboard", "test_integration_segmentation_3d", "test_integration_sliding_window", diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py new file mode 100644 index 0000000000..f4f3fa6d02 --- /dev/null +++ b/tests/test_add_extreme_points_channel.py @@ -0,0 +1,67 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import AddExtremePointsChannel + +IMG_CHANNEL = 3 + +TEST_CASE_1 = [ + { + "img": np.zeros((IMG_CHANNEL, 4, 3)), + "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), +] + +TEST_CASE_2 = [ + { + "img": np.zeros((IMG_CHANNEL, 4, 3)), + "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), +] + + +class TestAddExtremePointsChannel(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_correct_results(self, input_data, expected): + add_extreme_points_channel = AddExtremePointsChannel() + result = add_extreme_points_channel(**input_data) + np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py new file mode 100644 index 0000000000..4fee176b20 --- /dev/null +++ b/tests/test_add_extreme_points_channeld.py @@ -0,0 +1,57 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import AddExtremePointsChanneld + +IMG_CHANNEL = 3 + +TEST_CASE_1 = [ + {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])}, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), +] + +TEST_CASE_2 = [ + {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])}, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), +] + + +class TestAddExtremePointsChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_correct_results(self, input_data, expected): + add_extreme_points_channel = AddExtremePointsChanneld( + keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 + ) + result = add_extreme_points_channel(input_data) + np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index a7749d7f3a..86b31e0361 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -62,6 +62,15 @@ CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +TEST_CASE_FAIL = { # 2-channel 2D, should fail because of stride/channel mismatch. + "dimensions": 2, + "in_channels": 2, + "out_channels": 2, + "channels": (4, 8, 16), + "strides": (2, 2), +} + + class TestAutoEncoder(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): @@ -76,6 +85,10 @@ def test_script(self): test_data = torch.randn(2, 1, 32, 32) test_script_save(net, test_data) + def test_channel_stride_difference(self): + with self.assertRaises(ValueError): + net = AutoEncoder(**TEST_CASE_FAIL) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py new file mode 100644 index 0000000000..13aaaeb34e --- /dev/null +++ b/tests/test_bilateral_approx_cpu.py @@ -0,0 +1,381 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.631360, 0.099349, 0.070177, 0.164534, 0.649869] + ], + # Batch 1 + [ + # Channel 0 + [0.052271, 0.173599, 0.481337, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.145959, 0.142282, 0.315710, 0.135609, 0.132572] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.960843, 0.073540, 0.027689, 0.002676, 0.000000], + # Channel 1 + [0.960843, 0.073540, 0.951248, 0.003033, 0.000750], + # Channel 2 + [0.000000, 0.000000, 0.923559, 0.000357, 0.981324], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.980574], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.213684, 0.094356, 0.092973, 0.091650, 0.216281], + [0.094085, 0.092654, 0.091395, 0.090186, 0.089302], + [0.092436, 0.091150, 0.090008, 0.088896, 0.088897], + [0.090849, 0.089717, 0.088759, 0.087751, 0.088501], + [0.211458, 0.088334, 0.087495, 0.087049, 0.212173], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.033341, 0.031314, 0.029367, 0.027494, 0.025692], + [0.031869, 0.030632, 0.028820, 0.027074, 0.025454], + [0.030455, 0.029628, 0.084257, 0.026704, 0.025372], + [0.029095, 0.028391, 0.027790, 0.026375, 0.025292], + [0.027786, 0.027197, 0.026692, 0.026181, 0.025213], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.244373, 0.014488, 0.036589, 0.014226, 0.024329], + [0.014108, 0.014228, 0.014096, 0.013961, 0.013823], + [0.013574, 0.013757, 0.013836, 0.013699, 0.013558], + [0.013008, 0.013211, 0.013404, 0.013438, 0.013295], + [0.025179, 0.012634, 0.034555, 0.013050, 0.237582], + ], + # Channel 1 + [ + [0.271496, 0.015547, 0.439432, 0.015700, 0.089579], + [0.015252, 0.015702, 0.015779, 0.015859, 0.015940], + [0.015020, 0.015556, 0.015935, 0.016015, 0.016098], + [0.014774, 0.015331, 0.015860, 0.016171, 0.016255], + [0.107384, 0.015094, 0.462471, 0.016166, 0.263480], + ], + # Channel 2 + [ + [0.027123, 0.003527, 0.467273, 0.004912, 0.645776], + [0.003810, 0.004908, 0.005605, 0.006319, 0.007050], + [0.004816, 0.005991, 0.006989, 0.007716, 0.008459], + [0.005880, 0.007060, 0.008179, 0.009101, 0.009858], + [0.633398, 0.008191, 0.496893, 0.010376, 0.025898], + ], + # Channel 3 + [ + [0.000000, 0.002468, 0.064430, 0.003437, 0.580526], + [0.002666, 0.003434, 0.003922, 0.004422, 0.004933], + [0.003370, 0.004192, 0.004890, 0.005399, 0.005919], + [0.004115, 0.004940, 0.005723, 0.006368, 0.006898], + [0.551194, 0.005731, 0.068977, 0.007260, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.086801, 0.036670, 0.035971, 0.035304, 0.088456], + [0.036639, 0.035652, 0.035009, 0.034394, 0.033803], + [0.035899, 0.034897, 0.034136, 0.033566, 0.033129], + [0.035180, 0.034238, 0.033413, 0.032811, 0.032577], + [0.088290, 0.033597, 0.032821, 0.032134, 0.088786], + ], + # Frame 1 + [ + [0.036286, 0.035269, 0.034632, 0.034021, 0.033435], + [0.035398, 0.034485, 0.033922, 0.033381, 0.033177], + [0.034688, 0.033822, 0.033169, 0.032664, 0.032780], + [0.034024, 0.033234, 0.032533, 0.032005, 0.032388], + [0.033564, 0.032797, 0.032118, 0.031525, 0.032105], + ], + # Frame 2 + [ + [0.035225, 0.034169, 0.033404, 0.032843, 0.032766], + [0.034383, 0.033487, 0.032908, 0.032415, 0.032650], + [0.033691, 0.032921, 0.032353, 0.031900, 0.032384], + [0.033080, 0.032390, 0.031786, 0.031432, 0.032008], + [0.033099, 0.032373, 0.031737, 0.031479, 0.032054], + ], + # Frame 3 + [ + [0.034216, 0.033231, 0.032337, 0.031758, 0.032101], + [0.033456, 0.032669, 0.031913, 0.031455, 0.032034], + [0.032788, 0.032140, 0.031618, 0.031413, 0.031977], + [0.032221, 0.031650, 0.031145, 0.031130, 0.031652], + [0.032642, 0.031968, 0.031378, 0.031433, 0.032003], + ], + # Frame 4 + [ + [0.086207, 0.032335, 0.031499, 0.030832, 0.087498], + [0.032570, 0.031884, 0.031155, 0.030858, 0.031401], + [0.031967, 0.031417, 0.030876, 0.030881, 0.031388], + [0.031602, 0.031103, 0.030696, 0.030960, 0.031455], + [0.090599, 0.031546, 0.031127, 0.031386, 0.083483], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCpuApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_approx(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + fast_approx = True + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py new file mode 100644 index 0000000000..5ea0d997d1 --- /dev/null +++ b/tests/test_bilateral_approx_cuda.py @@ -0,0 +1,386 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.880626, 0.306148, 0.158734, 0.164534, 0.754386] + ], + # Batch 1 + [ + # Channel 0 + [0.019010, 0.104507, 0.605634, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.149889, 0.148226, 0.367978, 0.144023, 0.141317] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.988107, 0.061340, 0.001565, 0.000011, 0.000000], + # Channel 1 + [0.988107, 0.061340, 0.998000, 0.000016, 0.000123], + # Channel 2 + [0.000000, 0.000000, 0.996435, 0.000006, 0.999236], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.999113], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.211469, 0.094356, 0.092973, 0.091650, 0.211894], + [0.093755, 0.091753, 0.090524, 0.089343, 0.088384], + [0.091803, 0.089783, 0.088409, 0.087346, 0.086927], + [0.089938, 0.088126, 0.086613, 0.085601, 0.085535], + [0.208359, 0.086535, 0.085179, 0.084210, 0.205858], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.032760, 0.030146, 0.027442, 0.024643, 0.021744], + [0.030955, 0.029416, 0.026574, 0.023629, 0.020841], + [0.028915, 0.027834, 0.115442, 0.022515, 0.020442], + [0.026589, 0.025447, 0.024319, 0.021286, 0.019964], + [0.023913, 0.022704, 0.021510, 0.020388, 0.019379], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.557349, 0.011031, 0.001800, 0.011265, 0.000631], + [0.009824, 0.010361, 0.010429, 0.010506, 0.010595], + [0.008709, 0.009252, 0.009688, 0.009714, 0.009744], + [0.007589, 0.008042, 0.008576, 0.008887, 0.008852], + [0.000420, 0.006827, 0.001048, 0.007763, 0.190722], + ], + # Channel 1 + [ + [0.614072, 0.011045, 0.925766, 0.011287, 0.007548], + [0.009838, 0.010382, 0.010454, 0.010536, 0.010630], + [0.008727, 0.009277, 0.009720, 0.009751, 0.009787], + [0.007611, 0.008071, 0.008613, 0.008932, 0.008904], + [0.027088, 0.006859, 0.950749, 0.007815, 0.230270], + ], + # Channel 2 + [ + [0.056723, 0.000150, 0.973790, 0.000233, 0.990814], + [0.000151, 0.000214, 0.000257, 0.000307, 0.000364], + [0.000186, 0.000257, 0.000328, 0.000384, 0.000449], + [0.000221, 0.000295, 0.000382, 0.000465, 0.000538], + [0.993884, 0.000333, 0.984743, 0.000532, 0.039548], + ], + # Channel 3 + [ + [0.000000, 0.000136, 0.049824, 0.000210, 0.983897], + [0.000136, 0.000193, 0.000232, 0.000277, 0.000329], + [0.000168, 0.000232, 0.000297, 0.000347, 0.000405], + [0.000200, 0.000266, 0.000345, 0.000420, 0.000485], + [0.967217, 0.000301, 0.035041, 0.000481, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.085451, 0.037820, 0.036880, 0.035978, 0.084296], + [0.037939, 0.036953, 0.036155, 0.035385, 0.034640], + [0.037167, 0.036302, 0.035603, 0.034931, 0.034465], + [0.036469, 0.035724, 0.035137, 0.034572, 0.034480], + [0.088942, 0.035193, 0.034682, 0.034266, 0.090568], + ], + # Frame 1 + [ + [0.037125, 0.035944, 0.035103, 0.033429, 0.033498], + [0.033380, 0.032653, 0.033748, 0.033073, 0.032549], + [0.034834, 0.034001, 0.033500, 0.032902, 0.032560], + [0.033972, 0.033554, 0.033220, 0.032765, 0.032570], + [0.033590, 0.033222, 0.032927, 0.032689, 0.032629], + ], + # Frame 2 + [ + [0.035635, 0.034468, 0.033551, 0.032818, 0.032302], + [0.034523, 0.032830, 0.032146, 0.031536, 0.031149], + [0.033612, 0.032011, 0.031664, 0.031128, 0.030839], + [0.032801, 0.031668, 0.031529, 0.031198, 0.030978], + [0.032337, 0.031550, 0.031419, 0.031383, 0.031211], + ], + # Frame 3 + [ + [0.034300, 0.033236, 0.032239, 0.031517, 0.031133], + [0.033357, 0.031842, 0.031035, 0.030471, 0.030126], + [0.032563, 0.031094, 0.030156, 0.029703, 0.029324], + [0.031850, 0.030505, 0.030027, 0.029802, 0.029461], + [0.031555, 0.030121, 0.029943, 0.030000, 0.029700], + ], + # Frame 4 + [ + [0.083156, 0.032122, 0.031204, 0.030380, 0.080582], + [0.032296, 0.030936, 0.030170, 0.029557, 0.029124], + [0.031617, 0.030293, 0.029377, 0.028886, 0.028431], + [0.031084, 0.029859, 0.028839, 0.028439, 0.027973], + [0.164616, 0.029457, 0.028484, 0.028532, 0.211082], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cuda +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCudaApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_approx(self, test_case_description, sigmas, input, expected): + + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + fast_approx = True + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py new file mode 100644 index 0000000000..f2a265b106 --- /dev/null +++ b/tests/test_bilateral_precise.py @@ -0,0 +1,403 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999998, 0.000002, 0.000000, 0.000002, 0.999998] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.999995, 0.000001, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.813183, 0.186817, 0.061890, 0.186817, 0.813183] + ], + # Batch 1 + [ + # Channel 0 + [0.030148, 0.148418, 0.555452, 0.148418, 0.030148] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999999, 0.000009, 0.000009, 0.000009, 0.999999] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 0.999967, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.839145, 0.572834, 0.562460, 0.572834, 0.839145] + ], + # Batch 1 + [ + # Channel 0 + [0.049925, 0.055062, 0.171732, 0.055062, 0.049925] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.889742, 0.141296, 0.027504, 0.000000, 0.000000], + # Channel 1 + [0.909856, 0.256817, 0.725970, 0.115520, 0.020114], + # Channel 2 + [0.020114, 0.115520, 0.725970, 0.256817, 0.909856], + # Channel 3 + [0.000000, 0.000000, 0.027504, 0.141296, 0.889742], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.368574, 0.352546, 0.346955, 0.352546, 0.368574], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004836, 0.005314, 0.018598, 0.005314, 0.004836], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + ], + # Channel 1 + [ + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + ], + # Channel 2 + [ + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + ], + # Channel 3 + [ + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.251207, 0.241082, 0.237534, 0.241082, 0.251207], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + ], + # Frame 1 + [ + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + ], + # Frame 2 + [ + [0.251207, 0.241081, 0.237534, 0.241082, 0.251207], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.237534, 0.228048, 0.224724, 0.228049, 0.237534], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.251207, 0.241081, 0.237534, 0.241082, 0.251207], + ], + # Frame 3 + [ + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + ], + # Frame 4 + [ + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.251207, 0.241082, 0.237534, 0.241082, 0.251207], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCpuPrecised(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_precised(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + fast_approx = False + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-5) + + +@skip_if_no_cuda +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCudaPrecised(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_precised(self, test_case_description, sigmas, input, expected): + + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + fast_approx = False + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py new file mode 100644 index 0000000000..69476479a3 --- /dev/null +++ b/tests/test_bounding_rect.py @@ -0,0 +1,49 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +import monai +from monai.transforms import BoundingRect + +TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]] + +TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] + +TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] + + +class TestBoundingRect(unittest.TestCase): + def setUp(self): + monai.utils.set_determinism(1) + + def tearDown(self): + monai.utils.set_determinism(None) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_shape, expected): + test_data = np.random.randint(0, 8, size=input_shape) + test_data = test_data == 7 + result = BoundingRect()(test_data) + np.testing.assert_allclose(result, expected) + + def test_select_fn(self): + test_data = np.random.randint(0, 8, size=(2, 3)) + test_data = test_data == 7 + bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data) + np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py new file mode 100644 index 0000000000..c33a3c371d --- /dev/null +++ b/tests/test_bounding_rectd.py @@ -0,0 +1,49 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +import monai +from monai.transforms import BoundingRectD + +TEST_CASE_1 = [(2, 3), [[-1, -1], [1, 2]]] + +TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] + +TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] + + +class TestBoundingRectD(unittest.TestCase): + def setUp(self): + monai.utils.set_determinism(1) + + def tearDown(self): + monai.utils.set_determinism(None) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_shape, expected): + test_data = np.random.randint(0, 8, size=input_shape) + test_data = test_data == 7 + result = BoundingRectD("image")({"image": test_data}) + np.testing.assert_allclose(result["image_bbox"], expected) + + result = BoundingRectD("image", "cc")({"image": test_data}) + np.testing.assert_allclose(result["image_cc"], expected) + + with self.assertRaises(KeyError): + BoundingRectD("image", "cc")({"image": test_data, "image_cc": None}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py new file mode 100644 index 0000000000..00d0eab65a --- /dev/null +++ b/tests/test_channel_pad.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import ChannelPad + +TEST_CASES_3D = [] +for type_1 in ("pad", "project"): + input_shape = (16, 10, 32, 24, 48) + out_chns = 13 + result_shape = list(input_shape) + result_shape[1] = out_chns + test_case = [ + {"spatial_dims": 3, "in_channels": 10, "out_channels": out_chns, "mode": type_1}, + input_shape, + result_shape, + ] + TEST_CASES_3D.append(test_case) + + +class TestChannelPad(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_shape(self, input_param, input_shape, expected_shape): + net = ChannelPad(**input_param) + net.eval() + with torch.no_grad(): + result = net(torch.randn(input_shape)) + self.assertEqual(list(result.shape), list(expected_shape)) + + def test_wrong_mode(self): + with self.assertRaises(ValueError): + ChannelPad(3, 10, 20, mode="test")(torch.randn(10, 10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_compute_occlusion_sensitivity.py b/tests/test_compute_occlusion_sensitivity.py index 897177c6ed..9f30162c47 100644 --- a/tests/test_compute_occlusion_sensitivity.py +++ b/tests/test_compute_occlusion_sensitivity.py @@ -43,6 +43,7 @@ "label": 0, "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], "n_batch": 10, + "stride": 2, }, (2, 6, 6), ] diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py new file mode 100644 index 0000000000..2de3ee7394 --- /dev/null +++ b/tests/test_convert_to_multi_channeld.py @@ -0,0 +1,34 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ConvertToMultiChannelBasedOnBratsClassesd + +TEST_CASE = [ + {"keys": "label"}, + {"label": np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])}, + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + + +class TestConvertToMultiChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE]) + def test_type_shape(self, keys, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) + np.testing.assert_equal(result["label"], expected_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py new file mode 100644 index 0000000000..cbd281f6e8 --- /dev/null +++ b/tests/test_detect_envelope.py @@ -0,0 +1,164 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import DetectEnvelope +from monai.utils import InvalidPyTorchVersionError, OptionalImportError +from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, SkipIfModule, SkipIfNoModule + +n_samples = 500 +hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) + +# SINGLE-CHANNEL VALUE TESTS +# using np.expand_dims() to add length 1 channel dimension at dimension 0 + +TEST_CASE_1D_SINE = [ + {}, # args (empty, so use default) + np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave + np.expand_dims(np.hanning(n_samples), 0), # Expected output: the Hann window + 1e-4, # absolute tolerance +] + +TEST_CASE_2D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 10 identical windowed sine waves as a 2D numpy array + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), + # Expected output: Set of 10 identical Hann windows + np.expand_dims(np.stack([np.hanning(n_samples)] * 10, axis=1), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_3D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array + np.expand_dims(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), 0), + # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array + np.expand_dims(np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_2D_SINE_AXIS_1 = [ + {"axis": 2}, # set axis argument to 1 + # Create 10 identical windowed sine waves as a 2D numpy array + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), + # Expected output: absolute value of each sample of the waveform, repeated (i.e. flat envelopes) + np.expand_dims(np.abs(np.repeat(hann_windowed_sine, 10).reshape((n_samples, 10))), 0), + 1e-4, # absolute tolerance +] + +TEST_CASE_1D_SINE_PADDING_N = [ + {"n": 512}, # args (empty, so use default) + np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave + np.expand_dims(np.concatenate([np.hanning(500), np.zeros(12)]), 0), # Expected output: the Hann window + 1e-3, # absolute tolerance +] + +# MULTI-CHANNEL VALUE TEST + +TEST_CASE_2_CHAN_3D_SINE = [ + {}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1) + # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels) + np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels) + np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0), + 1e-4, # absolute tolerance +] + +# EXCEPTION TESTS + +TEST_CASE_INVALID_AXIS_1 = [ + {"axis": 3}, # set axis argument to 3 when only 3 dimensions (1 channel + 2 spatial) + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_AXIS_2 = [ + {"axis": -1}, # set axis argument negative + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__init__", # method expected to raise exception +] + +TEST_CASE_INVALID_N = [ + {"n": 0}, # set FFT length to zero + np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_DTYPE = [ + {}, + np.expand_dims(np.array(hann_windowed_sine, dtype=np.complex), 0), # complex numbers are invalid + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_IMG_LEN = [ + {}, + np.expand_dims(np.array([]), 0), # empty array is invalid + "__call__", # method expected to raise exception +] + +TEST_CASE_INVALID_OBJ = [{}, "a string", "__call__"] # method expected to raise exception + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestDetectEnvelope(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1D_SINE, + TEST_CASE_2D_SINE, + TEST_CASE_3D_SINE, + TEST_CASE_2D_SINE_AXIS_1, + TEST_CASE_1D_SINE_PADDING_N, + TEST_CASE_2_CHAN_3D_SINE, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = DetectEnvelope(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + @parameterized.expand( + [ + TEST_CASE_INVALID_AXIS_1, + TEST_CASE_INVALID_AXIS_2, + TEST_CASE_INVALID_N, + TEST_CASE_INVALID_DTYPE, + TEST_CASE_INVALID_IMG_LEN, + ] + ) + def test_value_error(self, arguments, image, method): + if method == "__init__": + self.assertRaises(ValueError, DetectEnvelope, **arguments) + elif method == "__call__": + self.assertRaises(ValueError, DetectEnvelope(**arguments), image) + else: + raise ValueError("Expected raising method invalid. Should be __init__ or __call__.") + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfModule("torch.fft") +class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): + self.assertRaises(OptionalImportError, DetectEnvelope(), np.random.rand(1, 10)) + + +@SkipIfAtLeastPyTorchVersion((1, 7)) +class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaisesRegexp(InvalidPyTorchVersionError, "version"): + DetectEnvelope() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index ca5e056a16..6b89c8c4fd 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -16,8 +16,7 @@ from parameterized import parameterized from monai.networks.nets import DynUNet - -# from tests.utils import test_script_save +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -111,14 +110,13 @@ def test_shape(self, input_param, input_shape, expected_shape): net.eval() with torch.no_grad(): result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - + self.assertEqual(result[0].shape, expected_shape) -# def test_script(self): -# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] -# net = DynUNet(**input_param) -# test_data = torch.randn(input_shape) -# test_script_save(net, test_data) + def test_script(self): + input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] + net = DynUNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) class TestDynUNetDeepSupervision(unittest.TestCase): diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 122b22dba4..2130234013 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -21,7 +21,7 @@ TEST_CASE_1 = [ { "layer_matches": [lambda x: x.model[-1]], - "match_types": ["select"], + "match_types": "select", "lr_values": [1], }, (1, 100), @@ -30,7 +30,7 @@ TEST_CASE_2 = [ { "layer_matches": [lambda x: x.model[-1], lambda x: x.model[-2], lambda x: x.model[-3]], - "match_types": ["select", "select", "select"], + "match_types": "select", "lr_values": [1, 2, 3], }, (1, 2, 3, 100), @@ -84,6 +84,30 @@ def test_lr_values(self, input_param, expected_values): for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)): torch.testing.assert_allclose(param_group["lr"], value) + n = [len(p["params"]) for p in params] + assert sum(n) == 26 or all(n), "should have either full model or non-empty subsets." + + def test_wrong(self): + """overlapped""" + device = "cuda" if torch.cuda.is_available() else "cpu" + net = Unet( + dimensions=3, + in_channels=1, + out_channels=3, + channels=(16, 32, 64), + strides=(2, 2), + num_res_units=1, + ).to(device) + + params = generate_param_groups( + network=net, + layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]], + match_types="select", + lr_values=0.1, + ) + with self.assertRaises(ValueError): + torch.optim.Adam(params, 100) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py new file mode 100644 index 0000000000..dd38af573e --- /dev/null +++ b/tests/test_get_extreme_points.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import get_extreme_points + +TEST_CASE_1 = [ + { + "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 0), (3, 0), (1, 2)], +] + +TEST_CASE_2 = [ + { + "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 1), (1, 0), (1, 2)], +] + + +class TestGetExtremePoints(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_type_shape(self, input_data, expected): + result = get_extreme_points(**input_data) + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 8513b6625e..2df36d9720 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -117,6 +117,7 @@ def _train_func(engine, batch): n_saved, ) handler.attach(engine) + engine.run(data, max_epochs=2) engine.run(data, max_epochs=5) for filename in filenames: self.assertTrue(os.path.exists(os.path.join(tempdir, filename))) diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py new file mode 100644 index 0000000000..67322718b1 --- /dev/null +++ b/tests/test_handler_hausdorff_distance.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Tuple + +import numpy as np +import torch + +from monai.handlers import HausdorffDistance + + +def create_spherical_seg_3d( + radius: float = 20.0, + centre: Tuple[int, int, int] = (49, 49, 49), + im_shape: Tuple[int, int, int] = (99, 99, 99), +) -> np.ndarray: + """ + Return a 3D image with a sphere inside. Voxel values will be + 1 inside the sphere, and 0 elsewhere. + + Args: + radius: radius of sphere (in terms of number of voxels, can be partial) + centre: location of sphere centre. + im_shape: shape of image to create + + See also: + :py:meth:`~create_test_image_3d` + """ + # Create image + image = np.zeros(im_shape, dtype=np.int32) + spy, spx, spz = np.ogrid[ + -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] + ] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius + + image[circle] = 1 + image[~circle] = 0 + return image + + +sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_zeros = torch.zeros_like(sampler_sphere) + +TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] +TEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt] +TEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt] +TEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros] + + +class TestHandlerHausdorffDistance(unittest.TestCase): + # TODO test multi node Hausdorff Distance + + def test_compute(self): + hd_metric = HausdorffDistance(include_background=True) + y_pred, y = TEST_SAMPLE_1 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), 10) + y_pred, y = TEST_SAMPLE_2 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), 5) + y_pred, y = TEST_SAMPLE_3 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric.compute(), float("inf")) + self.assertEqual(hd_metric._num_examples, 3) + y_pred, y = TEST_SAMPLE_4 + hd_metric.update([y_pred, y]) + self.assertEqual(hd_metric._num_examples, 3) + + def test_shape_mismatch(self): + hd_metric = HausdorffDistance(include_background=True) + with self.assertRaises((AssertionError, ValueError)): + y_pred = TEST_SAMPLE_1[0] + y = torch.ones((1, 1, 10, 10, 10)) + hd_metric.update([y_pred, y]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py new file mode 100644 index 0000000000..02898769f6 --- /dev/null +++ b/tests/test_handler_surface_distance.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Tuple + +import numpy as np +import torch + +from monai.handlers import SurfaceDistance + + +def create_spherical_seg_3d( + radius: float = 20.0, + centre: Tuple[int, int, int] = (49, 49, 49), + im_shape: Tuple[int, int, int] = (99, 99, 99), +) -> np.ndarray: + """ + Return a 3D image with a sphere inside. Voxel values will be + 1 inside the sphere, and 0 elsewhere. + + Args: + radius: radius of sphere (in terms of number of voxels, can be partial) + centre: location of sphere centre. + im_shape: shape of image to create + + See also: + :py:meth:`~create_test_image_3d` + """ + # Create image + image = np.zeros(im_shape, dtype=np.int32) + spy, spx, spz = np.ogrid[ + -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] + ] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius + + image[circle] = 1 + image[~circle] = 0 + return image + + +sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +sampler_sphere_zeros = torch.zeros_like(sampler_sphere) + +TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] +TEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt] +TEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt] +TEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros] + + +class TestHandlerSurfaceDistance(unittest.TestCase): + # TODO test multi node Surface Distance + + def test_compute(self): + sur_metric = SurfaceDistance(include_background=True) + y_pred, y = TEST_SAMPLE_1 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4) + y_pred, y = TEST_SAMPLE_2 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), 2.08566, places=4) + y_pred, y = TEST_SAMPLE_3 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric.compute(), float("inf")) + self.assertAlmostEqual(sur_metric._num_examples, 3) + y_pred, y = TEST_SAMPLE_4 + sur_metric.update([y_pred, y]) + self.assertAlmostEqual(sur_metric._num_examples, 3) + + def test_shape_mismatch(self): + sur_metric = SurfaceDistance(include_background=True) + with self.assertRaises((AssertionError, ValueError)): + y_pred = TEST_SAMPLE_1[0] + y = torch.ones((1, 1, 10, 10, 10)) + sur_metric.update([y_pred, y]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index dda1186612..96c52cbb68 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -13,27 +13,24 @@ from typing import Tuple import numpy as np +import torch from parameterized import parameterized -from monai.metrics import compute_hausdorff_distance +from monai.metrics import HausdorffDistanceMetric def create_spherical_seg_3d( radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), - labelfield_value: int = 1, - background_value: int = 0, im_shape: Tuple[int, int, int] = (99, 99, 99), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be - `labelfield_value` inside the sphere, and `background_value` elsewhere. + 1 inside the sphere, and 0 elsewhere. Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - labelfield_value: index of labelfield. - background_value: index of background. im_shape: shape of image to create See also: @@ -46,8 +43,8 @@ def create_spherical_seg_3d( ] circle = (spx * spx + spy * spy + spz * spz) <= radius * radius - image[circle] = labelfield_value - image[~circle] = background_value + image[circle] = 1 + image[~circle] = 0 return image @@ -60,15 +57,13 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), create_spherical_seg_3d(radius=20, centre=(19, 19, 19)), - 1, ], [1.7320508075688772, 1.7320508075688772, 1, 1, 3, 3], ], [ [ - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(19, 33, 22)), - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(20, 33, 22)), - 2, + create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), + create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], [1, 1, 1, 1, 1, 1], ], @@ -76,31 +71,22 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [20.09975124224178, 20.223748416156685, 15, 20, 24, 35], ], [ [ + # pred does not have foreground (but gt has), the metric should be inf np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, - ], - [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], - ], - [ - [ - np.zeros([99, 99, 99]), - np.zeros([99, 99, 99]), - 1, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], [ [ + # gt does not have foreground (but pred has), the metric should be inf create_spherical_seg_3d(), np.zeros([99, 99, 99]), - 1, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -108,32 +94,60 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, 95, ], [19.924858845171276, 20.09975124224178, 14, 18, 22, 33], ], ] +TEST_CASES_NANS = [ + [ + [ + # both pred and gt do not have foreground, metric and not_nans should be 0 + np.zeros([99, 99, 99]), + np.zeros([99, 99, 99]), + ], + ], +] + class TestHausdorffDistance(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): percentile = None - if len(input_data) == 4: - [seg_1, seg_2, label_idx, percentile] = input_data + if len(input_data) == 3: + [seg_1, seg_2, percentile] = input_data else: - [seg_1, seg_2, label_idx] = input_data + [seg_1, seg_2] = input_data ct = 0 + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) for metric in ["euclidean", "chessboard", "taxicab"]: for directed in [True, False]: - result = compute_hausdorff_distance( - seg_1, seg_2, label_idx, distance_metric=metric, percentile=percentile, directed=directed + hd_metric = HausdorffDistanceMetric( + include_background=False, distance_metric=metric, percentile=percentile, directed=directed ) + # shape of seg_1, seg_2 are: HWD, converts to BNHWD + batch, n_class = 2, 3 + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + result, _ = hd_metric(batch_seg_1, batch_seg_2) expected_value_curr = expected_value[ct] np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 + @parameterized.expand(TEST_CASES_NANS) + def test_nans(self, input_data): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + hd_metric = HausdorffDistanceMetric(include_background=False) + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) + result, not_nans = hd_metric(batch_seg_1, batch_seg_2) + np.testing.assert_allclose(0, result, rtol=1e-7) + np.testing.assert_allclose(0, not_nans, rtol=1e-7) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 6a4b129588..10f4f41fea 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -53,7 +53,7 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @TimedCall(seconds=100, force_quit=True) + @TimedCall(seconds=400, force_quit=True) def test_script(self): input_param, input_shape, expected_shape = TEST_CASE_1 net = HighResNet(**input_param) diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py new file mode 100644 index 0000000000..1099468102 --- /dev/null +++ b/tests/test_hilbert_transform.py @@ -0,0 +1,225 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import HilbertTransform +from monai.utils import InvalidPyTorchVersionError, OptionalImportError +from tests.utils import ( + SkipIfAtLeastPyTorchVersion, + SkipIfBeforePyTorchVersion, + SkipIfModule, + SkipIfNoModule, + skip_if_no_cuda, +) + + +def create_expected_numpy_output(input_datum, **kwargs): + + x = np.fft.fft( + input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), + **kwargs, + ) + f = np.fft.fftfreq(x.shape[kwargs["axis"]]) + u = np.heaviside(f, 0.5) + new_dims_before = kwargs["axis"] + new_dims_after = len(x.shape) - kwargs["axis"] - 1 + for _ in range(new_dims_before): + u = np.expand_dims(u, 0) + for _ in range(new_dims_after): + u = np.expand_dims(u, -1) + ht = np.fft.ifft(x * 2 * u, axis=kwargs["axis"]) + + return ht + + +cpu = torch.device("cpu") +n_samples = 500 +hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) + +# CPU TEST DATA + +cpu_input_data = dict() +cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0) +cpu_input_data["2D"] = ( + torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0) +) +cpu_input_data["3D"] = ( + torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu) + .unsqueeze(0) + .unsqueeze(0) +) +cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0) +cpu_input_data["2D 2CH"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu +).unsqueeze(0) + +# SINGLE-CHANNEL CPU VALUE TESTS + +TEST_CASE_1D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["1D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["1D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["2D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["2D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance +] + +TEST_CASE_3D_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["3D"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["3D"], axis=2), + 1e-5, # absolute tolerance +] + +# MULTICHANNEL CPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS + +TEST_CASE_1D_2CH_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["1D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["1D 2CH"], axis=2), + 1e-5, # absolute tolerance +] + +TEST_CASE_2D_2CH_SINE_CPU = [ + {}, # args (empty, so use default) + cpu_input_data["2D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(cpu_input_data["2D 2CH"], axis=2), + 1e-5, # absolute tolerance +] + +# GPU TEST DATA + +if torch.cuda.is_available(): + gpu = torch.device("cuda") + + gpu_input_data = dict() + gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0) + gpu_input_data["2D"] = ( + torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0) + ) + gpu_input_data["3D"] = ( + torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu) + .unsqueeze(0) + .unsqueeze(0) + ) + gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0) + gpu_input_data["2D 2CH"] = torch.as_tensor( + np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu + ).unsqueeze(0) + + # SINGLE CHANNEL GPU VALUE TESTS + + TEST_CASE_1D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["1D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + TEST_CASE_2D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["2D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + TEST_CASE_3D_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["3D"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal + 1e-5, # absolute tolerance + ] + + # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS + + TEST_CASE_1D_2CH_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["1D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2), + 1e-5, # absolute tolerance + ] + + TEST_CASE_2D_2CH_SINE_GPU = [ + {}, # args (empty, so use default) + gpu_input_data["2D 2CH"], # Input data: Random 1D signal + create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2), + 1e-5, # absolute tolerance + ] + +# TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestHilbertTransformCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_1D_SINE_CPU, + TEST_CASE_2D_SINE_CPU, + TEST_CASE_3D_SINE_CPU, + TEST_CASE_1D_2CH_SINE_CPU, + TEST_CASE_2D_2CH_SINE_CPU, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = HilbertTransform(**arguments)(image) + result = result.squeeze(0).squeeze(0).numpy() + np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) + + +@skip_if_no_cuda +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfNoModule("torch.fft") +class TestHilbertTransformGPU(unittest.TestCase): + @parameterized.expand( + [] + if not torch.cuda.is_available() + else [ + TEST_CASE_1D_SINE_GPU, + TEST_CASE_2D_SINE_GPU, + TEST_CASE_3D_SINE_GPU, + TEST_CASE_1D_2CH_SINE_GPU, + TEST_CASE_2D_2CH_SINE_GPU, + ], + skip_on_empty=True, + ) + def test_value(self, arguments, image, expected_data, atol): + result = HilbertTransform(**arguments)(image) + result = result.squeeze(0).squeeze(0).cpu().numpy() + np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol) + + +@SkipIfBeforePyTorchVersion((1, 7)) +@SkipIfModule("torch.fft") +class TestHilbertTransformNoFFTMod(unittest.TestCase): + def test_no_fft_module_error(self): + self.assertRaises(OptionalImportError, HilbertTransform(), torch.randn(1, 1, 10)) + + +@SkipIfAtLeastPyTorchVersion((1, 7)) +class TestHilbertTransformInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"): + HilbertTransform() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py new file mode 100644 index 0000000000..87777d83a3 --- /dev/null +++ b/tests/test_init_reader.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.data import ITKReader, NibabelReader, NumpyReader, PILReader +from monai.transforms import LoadImage, LoadImaged + + +class TestInitLoadImage(unittest.TestCase): + def test_load_image(self): + instance1 = LoadImage(image_only=False, dtype=None) + instance2 = LoadImage(image_only=True, dtype=None) + self.assertIsInstance(instance1, LoadImage) + self.assertIsInstance(instance2, LoadImage) + + for r in ["NibabelReader", "PILReader", "ITKReader", "NumpyReader", None]: + inst = LoadImaged("image", reader=r) + self.assertIsInstance(inst, LoadImaged) + + def test_readers(self): + inst = ITKReader() + self.assertIsInstance(inst, ITKReader) + + inst = NibabelReader() + self.assertIsInstance(inst, NibabelReader) + inst = NibabelReader(as_closest_canonical=True) + self.assertIsInstance(inst, NibabelReader) + + inst = NumpyReader() + self.assertIsInstance(inst, NumpyReader) + inst = NumpyReader(npz_keys="test") + self.assertIsInstance(inst, NumpyReader) + + inst = PILReader() + self.assertIsInstance(inst, PILReader) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index aadf379dcb..aa0fd57f76 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -243,7 +243,7 @@ def test_training(self): repeated.append(results) np.testing.assert_allclose(repeated[0], repeated[1]) - @TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False, force_quit=False) + @TimedCall(seconds=1000, skip_timing=not torch.cuda.is_available(), daemon=False) def test_timing(self): self.train_and_infer() diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 2a0b5d5d86..8e96947ccb 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -296,7 +296,7 @@ def train_and_infer(self, idx=0): def test_training(self): repeated = [] - test_rounds = 3 if monai.config.get_torch_version_tuple() >= (1, 6) else 2 + test_rounds = 3 if monai.utils.module.get_torch_version_tuple() >= (1, 6) else 2 for i in range(test_rounds): results = self.train_and_infer(idx=i) repeated.append(results) @@ -308,7 +308,7 @@ def test_training(self): daemon=False, ) def test_timing(self): - if monai.config.get_torch_version_tuple() >= (1, 6): + if monai.utils.module.get_torch_version_tuple() >= (1, 6): self.train_and_infer(idx=2) diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index b867e31e20..e4d79ad4bd 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -144,11 +144,15 @@ def test_shape(self, transform, expected_shape, kwargs=None): ] cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - dataset_precached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs) + dataset_precached = LMDBDataset( + data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs + ) data1_precached = dataset_precached[0] data2_precached = dataset_precached[1] - dataset_postcached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs) + dataset_postcached = LMDBDataset( + data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs + ) data1_postcached = dataset_postcached[0] data2_postcached = dataset_postcached[1] diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index a5021c5f26..06768f77b7 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -59,6 +59,7 @@ class TestNormalizeIntensity(NumpyImageTestCase2D): def test_default(self): normalizer = NormalizeIntensity() normalized = normalizer(self.imt) + self.assertTrue(normalized.dtype == np.float32) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) np.testing.assert_allclose(normalized, expected, rtol=1e-6) diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index 8b16dc4f35..dca3aaec12 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -13,27 +13,24 @@ from typing import Tuple import numpy as np +import torch from parameterized import parameterized -from monai.metrics import compute_average_surface_distance +from monai.metrics import SurfaceDistanceMetric def create_spherical_seg_3d( radius: float = 20.0, centre: Tuple[int, int, int] = (49, 49, 49), - labelfield_value: int = 1, - background_value: int = 0, im_shape: Tuple[int, int, int] = (99, 99, 99), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be - `labelfield_value` inside the sphere, and `background_value` elsewhere. + 1 inside the sphere, and 0 elsewhere. Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - labelfield_value: index of labelfield. - background_value: index of background. im_shape: shape of image to create See also: @@ -46,30 +43,28 @@ def create_spherical_seg_3d( ] circle = (spx * spx + spy * spy + spz * spz) <= radius * radius - image[circle] = labelfield_value - image[~circle] = background_value + image[circle] = 1 + image[~circle] = 0 return image TEST_CASES = [ [ - [create_spherical_seg_3d(), create_spherical_seg_3d(), 1], + [create_spherical_seg_3d(), create_spherical_seg_3d()], [0, 0], ], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), create_spherical_seg_3d(radius=20, centre=(19, 19, 19)), - 1, "taxicab", ], [1.0380029806259314, 1.0380029806259314], ], [ [ - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(19, 33, 22)), - create_spherical_seg_3d(radius=33, labelfield_value=2, centre=(20, 33, 22)), - 2, + create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), + create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), ], [0.35021200688332677, 0.3483278807706289], ], @@ -77,7 +72,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [13.975673696300824, 12.040033513150455], ], @@ -85,7 +79,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, "chessboard", ], [10.792254295459173, 9.605067064083457], @@ -94,7 +87,6 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, "taxicab", ], [17.32691760951026, 12.432687531048186], @@ -103,26 +95,26 @@ def create_spherical_seg_3d( [ np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), - 1, ], [np.inf, np.inf], ], [ [ + create_spherical_seg_3d(), np.zeros([99, 99, 99]), - np.zeros([99, 99, 99]), - 1, + "taxicab", ], [np.inf, np.inf], ], +] + +TEST_CASES_NANS = [ [ [ - create_spherical_seg_3d(), + # both pred and gt do not have foreground, metric and not_nans should be 0 + np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), - 1, - "taxicab", ], - [np.inf, np.inf], ], ] @@ -130,20 +122,37 @@ def create_spherical_seg_3d( class TestAllSurfaceMetrics(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): - if len(input_data) == 4: - [seg_1, seg_2, label_idx, metric] = input_data + if len(input_data) == 3: + [seg_1, seg_2, metric] = input_data else: - [seg_1, seg_2, label_idx] = input_data + [seg_1, seg_2] = input_data metric = "euclidean" ct = 0 + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) for symmetric in [True, False]: + sur_metric = SurfaceDistanceMetric(include_background=False, symmetric=symmetric, distance_metric=metric) + # shape of seg_1, seg_2 are: HWD, converts to BNHWD + batch, n_class = 2, 3 + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + result, _ = sur_metric(batch_seg_1, batch_seg_2) expected_value_curr = expected_value[ct] - result = compute_average_surface_distance( - seg_1, seg_2, label_idx, symmetric=symmetric, distance_metric=metric - ) np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 + @parameterized.expand(TEST_CASES_NANS) + def test_nans(self, input_data): + [seg_1, seg_2] = input_data + seg_1 = torch.tensor(seg_1) + seg_2 = torch.tensor(seg_2) + sur_metric = SurfaceDistanceMetric(include_background=False) + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) + result, not_nans = sur_metric(batch_seg_1, batch_seg_2) + np.testing.assert_allclose(0, result, rtol=1e-7) + np.testing.assert_allclose(0, not_nans, rtol=1e-7) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unet.py b/tests/test_unet.py index 5d95e66ba4..ed05fce552 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -72,7 +72,7 @@ (16, 3, 32, 64, 48), ] -TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalisation +TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization { "dimensions": 3, "in_channels": 4, diff --git a/tests/utils.py b/tests/utils.py index 3ab73a4fcd..0b6c4e7318 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,10 +28,13 @@ import torch import torch.distributed as dist +from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import optional_import, set_determinism +from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils.module import get_torch_version_tuple nib, _ = optional_import("nibabel") +ver, has_pkg_res = optional_import("pkg_resources", name="parse_version") quick_test_var = "QUICKTEST" @@ -66,6 +69,25 @@ def __call__(self, obj): return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj) +class SkipIfModule(object): + """Decorator to be used if test should be skipped + when optional module is present.""" + + def __init__(self, module_name): + self.module_name = module_name + self.module_avail = optional_import(self.module_name)[1] + + def __call__(self, obj): + return unittest.skipIf(self.module_avail, f"Skipping because optional module present: {self.module_name}")(obj) + + +def skip_if_no_cpp_extention(obj): + """ + Skip the unit tests if the cpp extention isnt available + """ + return unittest.skipIf(not USE_COMPILED, "Skipping cpp extention tests")(obj) + + def skip_if_no_cuda(obj): """ Skip the unit tests if torch.cuda.is_available is False @@ -80,6 +102,40 @@ def skip_if_windows(obj): return unittest.skipIf(sys.platform == "win32", "Skipping tests on Windows")(obj) +class SkipIfBeforePyTorchVersion(object): + """Decorator to be used if test should be skipped + with PyTorch versions older than that given.""" + + def __init__(self, pytorch_version_tuple): + self.min_version = pytorch_version_tuple + if has_pkg_res: + self.version_too_old = ver(torch.__version__) < ver(".".join(map(str, self.min_version))) + else: + self.version_too_old = get_torch_version_tuple() < self.min_version + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_old, f"Skipping tests that fail on PyTorch versions before: {self.min_version}" + )(obj) + + +class SkipIfAtLeastPyTorchVersion(object): + """Decorator to be used if test should be skipped + with PyTorch versions older than that given.""" + + def __init__(self, pytorch_version_tuple): + self.max_version = pytorch_version_tuple + if has_pkg_res: + self.version_too_new = ver(torch.__version__) >= ver(".".join(map(str, self.max_version))) + else: + self.version_too_new = get_torch_version_tuple() >= self.max_version + + def __call__(self, obj): + return unittest.skipIf( + self.version_too_new, f"Skipping tests that fail on PyTorch versions at least: {self.max_version}" + )(obj) + + def make_nifti_image(array, affine=None): """ Create a temporary nifti image on the disk and return the image name. @@ -457,11 +513,10 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): result1 = net(*inputs) result2 = reloaded_net(*inputs) set_determinism(seed=None) - # When using e.g., VAR, we will produce a tuple of outputs. - # Hence, convert all to tuples and then compare all elements. - if not isinstance(result1, tuple): - result1 = (result1,) - result2 = (result2,) + + # convert results to tuples if needed to allow iterating over pairs of outputs + result1 = ensure_tuple(result1) + result2 = ensure_tuple(result2) for i, (r1, r2) in enumerate(zip(result1, result2)): if None not in (r1, r2): # might be None