diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 3d1683d470..a4295c013c 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -8,6 +8,10 @@ on:
   pull_request:
 
 jobs:
+  # cache keys used by these jobs:
+  #   - docker-20-03-py3-pip- (shared)
+  #   - ubuntu py37 pip-
+  #   - os-latest-pip- (shared)
   flake8-py3:
     runs-on: ubuntu-latest
     steps:
@@ -16,10 +20,20 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: 3.7
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip wheel
-          pip install -r requirements-dev.txt
+          python -m pip install -r requirements-dev.txt
       - name: Lint with black formater
         run: |
           $(pwd)/runtests.sh --nounittests --black
@@ -51,6 +65,18 @@ jobs:
         run: |
           which python
           python -m pip install --upgrade pip wheel
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+          echo "::set-output name=dir::$(pip cache dir)"
+        shell: bash
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
       - if: runner.os == 'windows'
         name: Install torch cpu from pytorch.org (Windows only)
         run: |
@@ -63,6 +89,7 @@ jobs:
           cat "requirements-dev.txt"
           python -m pip install -r requirements-dev.txt
           python -m pip list
+          python setup.py develop  # compile the cpp extensions
       - name: Run quick tests (CPU ${{ runner.os }})
         run: |
           python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
@@ -77,6 +104,16 @@ jobs:
     runs-on: [self-hosted, linux, x64]
     steps:
       - uses: actions/checkout@v2
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }}
      - name: Install the dependencies
        run: |
          which python
@@ -93,7 +130,6 @@ jobs:
           python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
           python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
           ./runtests.sh --quick
-          pip list
           coverage xml
       - name: Upload coverage
         uses: codecov/codecov-action@v1
@@ -110,9 +146,23 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: '3.x'
-      - name: Install setuptools
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
+      - name: Install dependencies
         run: |
           python -m pip install --user --upgrade pip setuptools wheel twine
+          # install the latest pytorch for testing;
+          # "pip install monai*.tar.gz" will nevertheless build cpp/cuda with an
+          # isolated, fresh torch installation according to pyproject.toml
+          python -m pip install "torch>=1.4"
       - name: Test source archive and wheel file
         run: |
           git fetch --depth=1 origin +refs/tags/*:refs/tags/*
@@ -158,10 +208,20 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: 3.7
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip wheel
-          pip install -r docs/requirements.txt
+          python -m pip install -r docs/requirements.txt
       - name: Make html
         run: |
           cd docs/
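The `cache weekly timestamp` steps above key the pip cache on a year-week stamp, so each cache entry lives for at most one ISO week. A minimal sketch of the same computation in Python (assuming a platform whose `strftime` supports `%V`, as glibc and macOS do):

```python
# Reproduce the `date '+%Y-%V'` stamp used in the "cache weekly timestamp" steps.
# %Y is the year, %V the ISO week number: every run in the same ISO week computes
# the same suffix and hits the same actions/cache entry; when the week rolls
# over, the key changes and a fresh cache is populated.
from datetime import date

datew = date.today().strftime("%Y-%V")  # e.g. "2020-28"
key = f"Linux-pip-{datew}"  # mirrors ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
print(key)
```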
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index fed6de46a7..3c2c3b4168 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -7,6 +7,10 @@ on:
       - master
 
 jobs:
+  # cache keys used by these jobs:
+  #   - docker-20-03-py3-pip- (shared)
+  #   - ubuntu py36 37 38-pip-
+  #   - os-latest-pip- (shared)
   coverage-py3:
     container:
       image: nvcr.io/nvidia/pytorch:20.03-py3
@@ -14,6 +18,16 @@ jobs:
     runs-on: [self-hosted, linux, x64]
     steps:
       - uses: actions/checkout@v2
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }}
       - name: Install the dependencies
         run: |
           which python
@@ -21,9 +35,9 @@ jobs:
           python -m pip uninstall -y torch torchvision
           python -m pip install torch==1.4
           python -m pip install -r requirements-dev.txt
-          python -m pip list
       - name: Run unit tests report coverage
         run: |
+          python -m pip list
           nvidia-smi
           export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
           echo $CUDA_VISIBLE_DEVICES
@@ -50,6 +64,16 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: ${{ matrix.python-version }}
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }}
       - name: Install the dependencies
         run: |
           python -m pip install --upgrade pip wheel
@@ -57,6 +81,7 @@ jobs:
           python -m pip install -r requirements-dev.txt
       - name: Run quick tests CPU ubuntu
         run: |
+          python -m pip list
           python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
           ./runtests.sh --quick
           coverage xml
@@ -83,6 +108,18 @@ jobs:
         run: |
           which python
           python -m pip install --upgrade pip wheel
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+          echo "::set-output name=dir::$(pip cache dir)"
+        shell: bash
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
       - if: runner.os == 'windows'
         name: Install torch cpu from pytorch.org (Windows only)
         run: |
@@ -95,6 +132,7 @@ jobs:
           cat "requirements-dev.txt"
           python -m pip install -r requirements-dev.txt
           python -m pip list
+          python setup.py develop  # to compile extensions using the locally installed torch
       - name: Run quick tests (CPU ${{ runner.os }})
         run: |
           python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
@@ -109,6 +147,16 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: 3.7
+      - name: cache weekly timestamp
+        id: pip-cache
+        run: |
+          echo "::set-output name=datew::$(date '+%Y-%V')"
+      - name: cache for pip
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
       - name: Install the default branch
         run: |
           pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
diff --git a/docs/source/installation.md b/docs/source/installation.md
index 9f464a6863..1148d7a7c6 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -54,10 +54,13 @@ This command will create a ``MONAI/`` folder in your current directory.
 You can install it by running:
 ```bash
 cd MONAI/
-pip install -e .
+python setup.py develop
+
+# to uninstall the package, run:
+python setup.py develop --uninstall
 ```
 or simply adding the root directory of the cloned source code (e.g., ``/workspace/Documents/MONAI``) to your ``$PYTHONPATH``
-and the codebase is ready to use.
+and the codebase is ready to use (without the additional features of the MONAI C++/CUDA extensions).
 
 ## Validating the install
 
@@ -128,7 +131,7 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
 ```
 [nibabel, skimage, pillow, tensorboard, gdown, ignite]
 ```
-which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
+which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, `gdown`,
 and `pytorch-ignite` respectively.
 
 - `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index d7f1c4e3dd..7fa23a4f64 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -129,6 +129,11 @@ Layers
 .. autoclass:: monai.networks.layers.AffineTransform
     :members:
 
+`LLTM`
+~~~~~~
+.. autoclass:: monai.networks.layers.LLTM
+    :members:
+
 `Utilities`
 ~~~~~~~~~~~
 .. automodule:: monai.networks.layers.convutils
diff --git a/monai/__init__.py b/monai/__init__.py
index a179c73532..3bbb327206 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -22,7 +22,9 @@
 
 __basedir__ = os.path.dirname(__file__)
 
-excludes = "^(handlers)"  # the handlers have some external decorators the users may not have installed
+# the handlers have some external decorators that users may not have installed
+# *.so files and the folder "_C" may not exist when the cpp extensions are not compiled
+excludes = "(^(handlers))|((\\.so)$)|(_C)"
 
 # load directory modules only, skip loading individual files
 load_submodules(sys.modules[__name__], False, exclude_pattern=excludes)
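The new `excludes` pattern keeps `monai/__init__.py` from eagerly importing the compiled `_C` modules. A small sketch of how downstream code can probe for them at runtime with `monai.utils.optional_import` (the same mechanism `simplelayers.py` uses later in this patch): the second return value is a boolean flag rather than a raised `ImportError`, so a `$PYTHONPATH`-only install simply reports the extensions as unavailable.

```python
from monai.utils import optional_import

# Each call returns (module, available_flag); no ImportError is raised here,
# so this works whether or not `python setup.py develop` compiled the extensions.
_C, has_cpp = optional_import("monai._C")
_C_CUDA, has_cuda = optional_import("monai._C_CUDA")
print(f"cpp extension: {has_cpp}, cuda extension: {has_cuda}")
```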
diff --git a/monai/networks/extensions/lltm/lltm.cpp b/monai/networks/extensions/lltm/lltm.cpp
new file mode 100644
index 0000000000..c8e9077838
--- /dev/null
+++ b/monai/networks/extensions/lltm/lltm.cpp
@@ -0,0 +1,103 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+#include <vector>
+
+// s'(z) = (1 - s(z)) * s(z)
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
+  return (1 - s) * s;
+}
+
+// tanh'(z) = 1 - tanh^2(z)
+torch::Tensor d_tanh(torch::Tensor z) {
+  return 1 - z.tanh().pow(2);
+}
+
+// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0 }
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
+  auto e = z.exp();
+  auto mask = (alpha * (e - 1)) < 0;
+  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
+}
+
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
+
+  auto new_cell = old_cell + candidate_cell * input_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;
+
+  return {new_h,
+          new_cell,
+          input_gate,
+          output_gate,
+          candidate_cell,
+          X,
+          gate_weights};
+}
+
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
+  auto d_tanh_new_cell = output_gate * grad_h;
+  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
+
+  auto d_old_cell = d_new_cell;
+  auto d_candidate_cell = input_gate * d_new_cell;
+  auto d_input_gate = candidate_cell * d_new_cell;
+
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+  d_input_gate *= d_sigmoid(gates[0]);
+  d_output_gate *= d_sigmoid(gates[1]);
+  d_candidate_cell *= d_elu(gates[2]);
+
+  auto d_gates =
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+
+  auto d_weights = d_gates.t().mm(X);
+  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gates.mm(weights);
+  const auto state_size = grad_h.size(1);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("lltm_forward", &lltm_forward, "LLTM forward");
+  m.def("lltm_backward", &lltm_backward, "LLTM backward");
+}
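For comparison with the ATen calls above, here is a pure-PyTorch sketch of the same forward pass (a reference translation for readers, not part of the patch; the C++ version is what actually ships):

```python
import torch
import torch.nn.functional as F

def lltm_forward_reference(input, weights, bias, old_h, old_cell):
    """Line-by-line Python mirror of lltm_forward in lltm.cpp (reference only)."""
    X = torch.cat([old_h, input], dim=1)              # torch::cat({old_h, input}, 1)
    gate_weights = torch.addmm(bias, X, weights.t())  # torch::addmm(bias, X, weights.transpose(0, 1))
    gates = gate_weights.chunk(3, dim=1)

    input_gate = torch.sigmoid(gates[0])
    output_gate = torch.sigmoid(gates[1])
    candidate_cell = F.elu(gates[2], alpha=1.0)

    new_cell = old_cell + candidate_cell * input_gate
    new_h = torch.tanh(new_cell) * output_gate
    return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights
```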
diff --git a/monai/networks/extensions/lltm/lltm_cuda.cpp b/monai/networks/extensions/lltm/lltm_cuda.cpp
new file mode 100644
index 0000000000..ed27c25835
--- /dev/null
+++ b/monai/networks/extensions/lltm/lltm_cuda.cpp
@@ -0,0 +1,94 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);
+
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);
+
+// C++ interface
+
+// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weights);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(old_h);
+  CHECK_INPUT(old_cell);
+
+  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
+}
+
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  CHECK_INPUT(grad_h);
+  CHECK_INPUT(grad_cell);
+  CHECK_INPUT(input_gate);
+  CHECK_INPUT(output_gate);
+  CHECK_INPUT(candidate_cell);
+  CHECK_INPUT(X);
+  CHECK_INPUT(gate_weights);
+  CHECK_INPUT(weights);
+
+  return lltm_cuda_backward(
+      grad_h,
+      grad_cell,
+      new_cell,
+      input_gate,
+      output_gate,
+      candidate_cell,
+      X,
+      gate_weights,
+      weights);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("lltm_forward", &lltm_forward, "LLTM forward (CUDA)");
+  m.def("lltm_backward", &lltm_backward, "LLTM backward (CUDA)");
+}
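The `CHECK_INPUT` macros above make the CUDA binding fail fast on CPU or non-contiguous tensors instead of silently reading memory with the wrong layout. A caller-side sketch of satisfying those preconditions:

```python
import torch

x = torch.randn(32, 4).t()  # a transposed view: x.is_contiguous() is False
x = x.contiguous()          # satisfies CHECK_CONTIGUOUS
if torch.cuda.is_available():
    x = x.to("cuda:0")      # satisfies CHECK_CUDA
    assert x.is_cuda and x.is_contiguous()
```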
diff --git a/monai/networks/extensions/lltm/lltm_cuda_kernel.cu b/monai/networks/extensions/lltm/lltm_cuda_kernel.cu
new file mode 100644
index 0000000000..dd9aeeb024
--- /dev/null
+++ b/monai/networks/extensions/lltm/lltm_cuda_kernel.cu
@@ -0,0 +1,187 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+namespace {
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
+  return 1.0 / (1.0 + exp(-z));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
+  const auto s = sigmoid(z);
+  return (1.0 - s) * s;
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
+  const auto t = tanh(z);
+  return 1 - (t * t);
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
+  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
+  const auto e = exp(z);
+  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
+  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_forward_kernel(
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)){
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
+  }
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_backward_kernel(
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)){
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
+    const auto d_new_cell =
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
+
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
+
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
+  }
+}
+} // namespace
+
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+
+  const auto batch_size = old_cell.size(0);
+  const auto state_size = old_cell.size(1);
+
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+  }));
+
+  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
+}
+
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gates,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gates);
+
+  const auto batch_size = new_cell.size(0);
+  const auto state_size = new_cell.size(1);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_backward_cuda", ([&] {
+    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+  }));
+
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gate_weights.mm(weights);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
+}
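Both kernels above launch one thread per (batch, state-column) pair: `blockIdx.y` selects the batch row and the x-dimension is rounded up to whole blocks, which is why each kernel guards with a bounds check like `if (c < gates.size(2))`. The grid arithmetic, restated in Python:

```python
# Mirror of: const dim3 blocks((state_size + threads - 1) / threads, batch_size);
def launch_dims(batch_size: int, state_size: int, threads: int = 1024):
    blocks_x = (state_size + threads - 1) // threads  # ceil(state_size / threads)
    return (blocks_x, batch_size), threads

# ((1, 4), 1024): most threads idle for tiny states, caught by the bounds check.
print(launch_dims(batch_size=4, state_size=2))
```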
diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py
index b1ed11eeff..5338f22655 100644
--- a/monai/networks/layers/simplelayers.py
+++ b/monai/networks/layers/simplelayers.py
@@ -10,15 +10,19 @@
 # limitations under the License.
 
 from typing import Sequence, Union
-
+import math
 import torch
-import torch.nn as nn
+from torch import nn
+from torch.autograd import Function
 import torch.nn.functional as F
-
 from monai.networks.layers.convutils import gaussian_1d, same_padding
 from monai.utils import ensure_tuple_rep
+from monai.utils import optional_import
 
-__all__ = ["SkipConnection", "Flatten", "GaussianFilter"]
+_C, _ = optional_import("monai._C")
+_C_CUDA, _ = optional_import("monai._C_CUDA")
+
+__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM"]
 
 
 class SkipConnection(nn.Module):
@@ -113,3 +117,58 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
             return self.conv_n(input=_conv(input_, d - 1), weight=kernel, padding=padding, groups=chns)
 
         return _conv(x, sp_dim - 1)
+
+
+class LLTMFunction(Function):
+    @staticmethod
+    def forward(ctx, input, weights, bias, old_h, old_cell):
+        ext = _C_CUDA if weights.is_cuda else _C
+        outputs = ext.lltm_forward(input, weights, bias, old_h, old_cell)
+        new_h, new_cell = outputs[:2]
+        variables = outputs[1:] + [weights]
+        ctx.save_for_backward(*variables)
+
+        return new_h, new_cell
+
+    @staticmethod
+    def backward(ctx, grad_h, grad_cell):
+        if grad_h.is_cuda:
+            outputs = _C_CUDA.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
+            d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
+        else:
+            outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
+            d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
+
+        return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+class LLTM(nn.Module):
+    """
+    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
+    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
+    Because this unit never forgets, it is called LLTM, or Long-Long-Term-Memory unit.
+    It has both C++ and CUDA implementations, and it automatically switches between
+    them according to the device this module is placed on.
+
+    Args:
+        input_features: size of input feature data
+        state_size: size of the state of the recurrent unit
+
+    See: https://pytorch.org/tutorials/advanced/cpp_extension.html
+    """
+
+    def __init__(self, input_features: int, state_size: int):
+        super(LLTM, self).__init__()
+        self.input_features = input_features
+        self.state_size = state_size
+        self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size))
+        self.bias = nn.Parameter(torch.empty(1, 3 * state_size))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1.0 / math.sqrt(self.state_size)
+        for weight in self.parameters():
+            weight.data.uniform_(-stdv, +stdv)
+
+    def forward(self, input, state):
+        return LLTMFunction.apply(input, self.weights, self.bias, *state)
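A usage sketch for the module above, following the shapes used in `tests/test_lltm.py`: the same code runs the C++ path on CPU tensors and the CUDA path once the module and inputs are moved to a GPU device.

```python
import torch
from monai.networks.layers import LLTM

lltm = LLTM(input_features=32, state_size=2)
x = torch.randn(4, 32)                       # batch of 4 input vectors
h, c = torch.randn(4, 2), torch.randn(4, 2)  # initial hidden and cell state
new_h, new_c = lltm(x, (h, c))
(new_h.sum() + new_c.sum()).backward()       # exercises the custom backward
```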
diff --git a/pyproject.toml b/pyproject.toml
index 36ececc145..9c49e77242 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,11 @@
+[build-system]
+requires = [
+  "wheel",
+  "setuptools",
+  "torch>=1.4",
+  "ninja",
+]
+
 [tool.black]
 line-length = 120
 target-version = ['py36', 'py37', 'py38']
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 50dcde1ee3..36be129174 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -22,3 +22,4 @@ black
 isort
 pytype>=2020.6.1
 mypy
+ninja
diff --git a/runtests.sh b/runtests.sh
index d7657d0ee1..4e848d7161 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -45,10 +45,11 @@ function print_usage {
     echo "MONAI unit testing utilities."
    echo ""
     echo "Examples:"
-    echo "./runtests.sh --codeformat --coverage      # runs full tests (${green}recommended before making pull requests${noColor})."
-    echo "./runtests.sh --codeformat --nounittests   # runs coding style and static type checking."
-    echo "./runtests.sh --quick                      # runs minimal unit tests, for quick verification during code developments."
-    echo "./runtests.sh --black-fix                  # runs automatic code formatting using \"black\"."
+    echo "./runtests.sh --codeformat --coverage      # run full tests (${green}recommended before making pull requests${noColor})."
+    echo "./runtests.sh --codeformat --nounittests   # run coding style and static type checking."
+    echo "./runtests.sh --quick                      # run minimal unit tests, for quick verification during code development."
+    echo "./runtests.sh --black-fix                  # run automatic code formatting using \"black\"."
+    echo "./runtests.sh --clean                      # clean up temporary files and run \"python setup.py develop --uninstall\"."
     echo ""
     echo "Code style check options:"
     echo "    --black           : perform \"black\" code format checks"
@@ -86,8 +87,40 @@ function print_version {
 }
 
 function install_deps {
-    echo "Pip installing MONAI development dependencies..."
-    ${cmdPrefix}pip install -r requirements-dev.txt
+    echo "Pip installing MONAI development dependencies and compiling MONAI cpp extensions..."
+    ${cmdPrefix}python -m pip install -r requirements-dev.txt
+}
+
+function compile_cpp {
+    echo "Compiling and installing MONAI cpp extensions..."
+    ${cmdPrefix}python setup.py -v develop --uninstall
+    if [[ "$OSTYPE" == "darwin"* ]];
+    then  # clang for Mac OS
+        CC=clang CXX=clang++ ${cmdPrefix}python setup.py -v develop
+    else
+        ${cmdPrefix}python setup.py -v develop
+    fi
+}
+
+function clean_py() {
+    # uninstall the development package
+    echo "Uninstalling MONAI development files..."
+    ${cmdPrefix}python setup.py -v develop --uninstall
+
+    # remove temporary files
+    echo "Removing temporary files..."
+    TO_CLEAN=${*:-'.'}
+    find ${TO_CLEAN} -type f -name "*.py[co]" -delete
+    find ${TO_CLEAN} -type f -name ".coverage" -delete
+    find ${TO_CLEAN} -type d -name "__pycache__" -delete
+
+    find ${TO_CLEAN} -depth -type d -name ".eggs" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name "monai.egg-info" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name "build" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name "dist" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name ".mypy_cache" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name ".pytype" -exec rm -r "{}" +
+    find ${TO_CLEAN} -depth -type d -name ".coverage" -exec rm -r "{}" +
 }
 
 function torch_validate {
@@ -197,42 +230,21 @@ then
     function dryrun { echo "    " "$@"; }
 fi
 
-# unconditionally report on the state of monai
-print_version
-
-
 if [ $doCleanup = true ]
 then
     echo "${separator}${blue}clean${noColor}"
 
-    if [ -d .mypy_cache ]
-    then
-        ${cmdPrefix}rm -r .mypy_cache
-    elif [ -f .mypy_cache ]
-    then
-        ${cmdPrefix}rm .mypy_cache
-    fi
-
-    if [ -d .pytype ]
-    then
-        ${cmdPrefix}rm -r .pytype
-    elif [ -f .pytype ]
-    then
-        ${cmdPrefix}rm .pytype
-    fi
-
-    if [ -d .coverage ]
-    then
-        ${cmdPrefix}rm -r .coverage
-    elif [ -f .coverage ]
-    then
-        ${cmdPrefix}rm .coverage
-    fi
+    clean_py
 
     echo "${green}done!${noColor}"
     exit
 fi
 
+# try to compile the MONAI cpp extensions
+compile_cpp
+
+# unconditionally report on the state of monai
+print_version
 
 if [ $doBlackFormat = true ]
 then
diff --git a/setup.cfg b/setup.cfg
index dee943d076..24a7de1b14 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,6 +11,9 @@ license = Apache License 2.0
 
 [options]
 python_requires = >= 3.6
+setup_requires =
+  torch>=1.4
+  ninja
 install_requires =
   torch>=1.4
   numpy>=1.17
@@ -48,13 +51,13 @@ ignore =
   # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F'
   N812
 per-file-ignores = __init__.py: F401
-exclude = *.pyi,.git,monai/_version.py,versioneer.py,venv, _version.py
+exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,_version.py
 
 [isort]
 known_first_party = monai
 profile = black
 line_length = 120
-skip = .git, venv, versioneer.py, _version.py, conf.py
+skip = .git, .eggs, venv, versioneer.py, _version.py, conf.py
 skip_glob = *.pyi
 
 [versioneer]
@@ -108,11 +111,15 @@ ignore_errors = True
 # Always ignore any type issues in the monai/._version.py file
 ignore_errors = True
 
+[mypy-monai.eggs]
+# Always ignore any type issues in the .eggs folder
+ignore_errors = True
+
 [pytype]
 # NOTE: All relative paths are relative to the location of this file.
 # Space-separated list of files or directories to exclude.
 #   i.e. exclude = **/*_test.py **/test_*.py
-exclude = **/versioneer.py _version.py
+exclude = **/versioneer.py _version.py .eggs
 # Space-separated list of files or directories to process.
 inputs = monai
 # Keep going past errors to analyze as many files as possible.
diff --git a/setup.py b/setup.py
index 5158fa1fb9..aaced090b0 100644
--- a/setup.py
+++ b/setup.py
@@ -9,15 +9,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
+
 from setuptools import find_packages, setup
 
 import versioneer
 
-if __name__ == "__main__":
-    setup(
-        version=versioneer.get_version(),
-        cmdclass=versioneer.get_cmdclass(),
-        packages=find_packages(exclude=("docs", "examples", "tests", "research")),
-        zip_safe=False,
-        package_data={"monai": ["py.typed"]},
-    )
+
+def get_extensions():
+    try:
+        import torch
+        from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+
+        print(f"setup.py with torch {torch.__version__}")
+    except ImportError:
+        warnings.warn("torch cpp/cuda building skipped.")
+        return []
+
+    ext_modules = [CppExtension("monai._C", ["monai/networks/extensions/lltm/lltm.cpp"])]
+    if torch.cuda.is_available() and (CUDA_HOME is not None):
+        ext_modules.append(
+            CUDAExtension(
+                "monai._C_CUDA",
+                ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu"],
+            )
+        )
+    return ext_modules
+
+
+def get_cmds():
+    cmds = versioneer.get_cmdclass()
+    try:
+        from torch.utils.cpp_extension import BuildExtension
+
+        cmds.update({"build_ext": BuildExtension})
+    except ImportError:
+        warnings.warn("torch cpp_extension module not found.")
+    return cmds
+
+
+setup(
+    version=versioneer.get_version(),
+    cmdclass=get_cmds(),
+    packages=find_packages(exclude=("docs", "examples", "tests", "research")),
+    zip_safe=False,
+    package_data={"monai": ["py.typed"]},
+    ext_modules=get_extensions(),
+)
diff --git a/tests/test_lltm.py b/tests/test_lltm.py
new file mode 100644
index 0000000000..5c666a8794
--- /dev/null
+++ b/tests/test_lltm.py
@@ -0,0 +1,55 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.layers import LLTM
+
+TEST_CASE_1 = [
+    {"input_features": 32, "state_size": 2},
+    torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),
+    torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),
+]
+
+
+class TestLLTM(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1])
+    def test_value(self, input_param, expected_h, expected_c):
+        torch.manual_seed(0)
+        x = torch.randn(4, 32)
+        h = torch.randn(4, 2)
+        c = torch.randn(4, 2)
+        new_h, new_c = LLTM(**input_param)(x, (h, c))
+        (new_h.sum() + new_c.sum()).backward()
+
+        torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
+        torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
+
+    @parameterized.expand([TEST_CASE_1])
+    def test_value_cuda(self, input_param, expected_h, expected_c):
+        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
+        torch.manual_seed(0)
+        x = torch.randn(4, 32).to(device)
+        h = torch.randn(4, 2).to(device)
+        c = torch.randn(4, 2).to(device)
+        lltm = LLTM(**input_param).to(device)
+        new_h, new_c = lltm(x, (h, c))
+        (new_h.sum() + new_c.sum()).backward()
+
+        torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04)
+        torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/testing_data/._Task04_Hippocampus b/tests/testing_data/._Task04_Hippocampus
deleted file mode 100755
index 7e298b58a4..0000000000
Binary files a/tests/testing_data/._Task04_Hippocampus and /dev/null differ
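The tests above pin down forward values on both devices. A natural companion (not included in the patch) is a numerical check of the hand-written backward against autograd via `torch.autograd.gradcheck`, which requires double-precision inputs and a compiled extension:

```python
import torch
from monai.networks.layers.simplelayers import LLTMFunction

kwargs = {"dtype": torch.float64, "requires_grad": True}
x = torch.randn(4, 32, **kwargs)
weights = torch.randn(3 * 2, 32 + 2, **kwargs)  # (3 * state_size, input_features + state_size)
bias = torch.randn(1, 3 * 2, **kwargs)
h = torch.randn(4, 2, **kwargs)
c = torch.randn(4, 2, **kwargs)
print(torch.autograd.gradcheck(LLTMFunction.apply, (x, weights, bias, h, c)))
```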