From 57d3f61db14ad05365b48bb0e34e07f8bbfa4731 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 20 Jul 2020 11:19:51 +0800
Subject: [PATCH 01/22] [DLMED] add CUDA extension test program

---
 tests/test_cuda_extension/__init__.py         |   0
 tests/test_cuda_extension/jit.py              |   4 +
 tests/test_cuda_extension/lltm.py             |  45 +++++
 tests/test_cuda_extension/lltm_cuda.cpp       |  81 ++++++++
 tests/test_cuda_extension/lltm_cuda_kernel.cu | 174 ++++++++++++++++++
 tests/test_cuda_extension/setup.py            |  14 ++
 6 files changed, 318 insertions(+)
 create mode 100644 tests/test_cuda_extension/__init__.py
 create mode 100644 tests/test_cuda_extension/jit.py
 create mode 100644 tests/test_cuda_extension/lltm.py
 create mode 100644 tests/test_cuda_extension/lltm_cuda.cpp
 create mode 100644 tests/test_cuda_extension/lltm_cuda_kernel.cu
 create mode 100644 tests/test_cuda_extension/setup.py

diff --git a/tests/test_cuda_extension/__init__.py b/tests/test_cuda_extension/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/test_cuda_extension/jit.py b/tests/test_cuda_extension/jit.py
new file mode 100644
index 0000000000..6c52efff8f
--- /dev/null
+++ b/tests/test_cuda_extension/jit.py
@@ -0,0 +1,4 @@
+from torch.utils.cpp_extension import load
+lltm_cuda = load(
+    'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
+help(lltm_cuda)
diff --git a/tests/test_cuda_extension/lltm.py b/tests/test_cuda_extension/lltm.py
new file mode 100644
index 0000000000..c740b8864f
--- /dev/null
+++ b/tests/test_cuda_extension/lltm.py
@@ -0,0 +1,45 @@
+import math
+from torch import nn
+from torch.autograd import Function
+import torch
+
+import lltm_cuda
+
+torch.manual_seed(42)
+
+
+class LLTMFunction(Function):
+    @staticmethod
+    def forward(ctx, input, weights, bias, old_h, old_cell):
+        outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
+        new_h, new_cell = outputs[:2]
+        variables = outputs[1:] + [weights]
+        ctx.save_for_backward(*variables)
+
+        return new_h, new_cell
+
+    @staticmethod
+    def backward(ctx, grad_h, grad_cell):
+        outputs = lltm_cuda.backward(
+            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables)
+        d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs
+        return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+class LLTM(nn.Module):
+    def __init__(self, input_features, state_size):
+        super(LLTM, self).__init__()
+        self.input_features = input_features
+        self.state_size = state_size
+        self.weights = nn.Parameter(
+            torch.Tensor(3 * state_size, input_features + state_size))
+        self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1.0 / math.sqrt(self.state_size)
+        for weight in self.parameters():
+            weight.data.uniform_(-stdv, +stdv)
+
+    def forward(self, input, state):
+        return LLTMFunction.apply(input, self.weights, self.bias, *state)
diff --git a/tests/test_cuda_extension/lltm_cuda.cpp b/tests/test_cuda_extension/lltm_cuda.cpp
new file mode 100644
index 0000000000..2434776abd
--- /dev/null
+++ b/tests/test_cuda_extension/lltm_cuda.cpp
@@ -0,0 +1,81 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);
+
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);
+
+// C++ interface
+
+// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weights);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(old_h);
+  CHECK_INPUT(old_cell);
+
+  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
+}
+
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  CHECK_INPUT(grad_h);
+  CHECK_INPUT(grad_cell);
+  CHECK_INPUT(input_gate);
+  CHECK_INPUT(output_gate);
+  CHECK_INPUT(candidate_cell);
+  CHECK_INPUT(X);
+  CHECK_INPUT(gate_weights);
+  CHECK_INPUT(weights);
+
+  return lltm_cuda_backward(
+      grad_h,
+      grad_cell,
+      new_cell,
+      input_gate,
+      output_gate,
+      candidate_cell,
+      X,
+      gate_weights,
+      weights);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
+  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
+}
diff --git a/tests/test_cuda_extension/lltm_cuda_kernel.cu b/tests/test_cuda_extension/lltm_cuda_kernel.cu
new file mode 100644
index 0000000000..02bb9ade5d
--- /dev/null
+++ b/tests/test_cuda_extension/lltm_cuda_kernel.cu
@@ -0,0 +1,174 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+namespace {
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
+  return 1.0 / (1.0 + exp(-z));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
+  const auto s = sigmoid(z);
+  return (1.0 - s) * s;
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
+  const auto t = tanh(z);
+  return 1 - (t * t);
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
+  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
+  const auto e = exp(z);
+  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
+  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_forward_kernel(
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  //batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)){
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
+  }
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_backward_kernel(
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  //batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)){
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
+    const auto d_new_cell =
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
+
+
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
+
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
+  }
+}
+} // namespace
+
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+
+  const auto batch_size = old_cell.size(0);
+  const auto state_size = old_cell.size(1);
+
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
+  }));
+
+  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
+}
+
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gates,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gates);
+
+  const auto batch_size = new_cell.size(0);
+  const auto state_size = new_cell.size(1);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
+  }));
+
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gate_weights.mm(weights);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
+}
diff --git a/tests/test_cuda_extension/setup.py b/tests/test_cuda_extension/setup.py
new file mode 100644
index 0000000000..670b3c86e6
--- /dev/null
+++ b/tests/test_cuda_extension/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+    name='lltm_cuda',
+    ext_modules=[
+        CUDAExtension('lltm_cuda', [
+            'lltm_cuda.cpp',
+            'lltm_cuda_kernel.cu',
+        ]),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })

From 6a7b337aa6f63208cd2cd2380d0a5734172969a3 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 28 Jul 2020 07:47:46 +0800
Subject: [PATCH 02/22] [DLMED] add unit tests and enable CI

---
 .github/workflows/pythonapp.yml               |   3 +
 monai/networks/extensions/__init__.py         |  12 ++
 monai/networks/extensions/lltm/__init__.py    |  10 ++
 monai/networks/extensions/lltm/lltm.cpp       | 103 ++++++++++++++++++
 monai/networks/extensions/lltm/lltm.py        |  61 +++++++++++
 .../networks/extensions/lltm}/lltm_cuda.cpp   |  13 +++
 .../extensions/lltm}/lltm_cuda_kernel.cu      |  13 +++
 setup_cpp.py                                  |  23 ++++
 setup_cuda.py                                 |  25 +++++
 tests/test_cuda_extension/__init__.py         |   0
 tests/test_cuda_extension/jit.py              |   4 -
 tests/test_cuda_extension/lltm.py             |  45 --------
 tests/test_cuda_extension/setup.py            |  14 ---
 tests/test_lltm.py                            |  55 ++++++++++
 14 files changed, 318 insertions(+), 63 deletions(-)
 create mode 100644 monai/networks/extensions/__init__.py
 create mode 100644 monai/networks/extensions/lltm/__init__.py
 create mode 100644 monai/networks/extensions/lltm/lltm.cpp
 create mode 100644 monai/networks/extensions/lltm/lltm.py
 rename {tests/test_cuda_extension => monai/networks/extensions/lltm}/lltm_cuda.cpp (78%)
 rename {tests/test_cuda_extension => monai/networks/extensions/lltm}/lltm_cuda_kernel.cu (92%)
 create mode 100644 setup_cpp.py
 create mode 100644 setup_cuda.py
 delete mode 100644 tests/test_cuda_extension/__init__.py
 delete mode 100644 tests/test_cuda_extension/jit.py
 delete mode 100644 tests/test_cuda_extension/lltm.py
 delete mode 100644 tests/test_cuda_extension/setup.py
 create mode 100644 tests/test_lltm.py

diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 3d1683d470..eb4338be57 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -63,6 +63,7 @@ jobs:
         cat "requirements-dev.txt"
         python -m pip install -r requirements-dev.txt
         python -m pip list
+        python setup_cpp.py install
     - name: Run quick tests (CPU ${{ runner.os }})
       run: |
         python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
@@
-84,6 +85,8 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt + python setup_cpp.py install + python setup_cuda.py install - name: Run quick tests (GPU) run: | python -m pip list diff --git a/monai/networks/extensions/__init__.py b/monai/networks/extensions/__init__.py new file mode 100644 index 0000000000..899ced6864 --- /dev/null +++ b/monai/networks/extensions/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .lltm.lltm import LLTM diff --git a/monai/networks/extensions/lltm/__init__.py b/monai/networks/extensions/lltm/__init__.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/monai/networks/extensions/lltm/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/networks/extensions/lltm/lltm.cpp b/monai/networks/extensions/lltm/lltm.cpp new file mode 100644 index 0000000000..5d098b461c --- /dev/null +++ b/monai/networks/extensions/lltm/lltm.cpp @@ -0,0 +1,103 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include + +#include + +// s'(z) = (1 - s(z)) * s(z) +torch::Tensor d_sigmoid(torch::Tensor z) { + auto s = torch::sigmoid(z); + return (1 - s) * s; +} + +// tanh'(z) = 1 - tanh^2(z) +torch::Tensor d_tanh(torch::Tensor z) { + return 1 - z.tanh().pow(2); +} + +// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0} +torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) { + auto e = z.exp(); + auto mask = (alpha * (e - 1)) < 0; + return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e); +} + +std::vector lltm_forward( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell) { + auto X = torch::cat({old_h, input}, /*dim=*/1); + + auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); + auto gates = gate_weights.chunk(3, /*dim=*/1); + + auto input_gate = torch::sigmoid(gates[0]); + auto output_gate = torch::sigmoid(gates[1]); + auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0); + + auto new_cell = old_cell + candidate_cell * input_gate; + auto new_h = torch::tanh(new_cell) * output_gate; + + return {new_h, + new_cell, + input_gate, + output_gate, + candidate_cell, + X, + gate_weights}; +} + +std::vector lltm_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights) { + auto d_output_gate = torch::tanh(new_cell) * grad_h; + auto d_tanh_new_cell = output_gate * grad_h; + auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell; + + auto d_old_cell = d_new_cell; + auto d_candidate_cell = input_gate * d_new_cell; + auto d_input_gate = candidate_cell * d_new_cell; + + auto gates = gate_weights.chunk(3, /*dim=*/1); + d_input_gate *= d_sigmoid(gates[0]); + d_output_gate *= d_sigmoid(gates[1]); + d_candidate_cell *= d_elu(gates[2]); + + auto d_gates = + torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); + + auto d_weights = d_gates.t().mm(X); + auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true); + + auto d_X = d_gates.mm(weights); + const auto state_size = grad_h.size(1); + auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); + auto d_input = d_X.slice(/*dim=*/1, state_size); + + return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &lltm_forward, "LLTM forward"); + m.def("backward", &lltm_backward, "LLTM backward"); +} diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py new file mode 100644 index 0000000000..ea724c0248 --- /dev/null +++ b/monai/networks/extensions/lltm/lltm.py @@ -0,0 +1,61 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from torch import nn +from torch.autograd import Function +import torch + +from monai.utils import optional_import +lltm_cpp, _ = optional_import("lltm_cpp") +lltm_cuda, _ = optional_import("lltm_cuda") + + +class LLTMFunction(Function): + @staticmethod + def forward(ctx, input, weights, bias, old_h, old_cell): + ext = lltm_cuda if weights.is_cuda else lltm_cpp + outputs = ext.forward(input, weights, bias, old_h, old_cell) + new_h, new_cell = outputs[:2] + variables = outputs[1:] + [weights] + ctx.save_for_backward(*variables) + + return new_h, new_cell + + @staticmethod + def backward(ctx, grad_h, grad_cell): + if grad_h.is_cuda: + outputs = lltm_cuda.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) + d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs + else: + outputs = lltm_cpp.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) + d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs + + return d_input, d_weights, d_bias, d_old_h, d_old_cell + + +class LLTM(nn.Module): + def __init__(self, input_features, state_size): + super(LLTM, self).__init__() + self.input_features = input_features + self.state_size = state_size + self.weights = nn.Parameter( + torch.Tensor(3 * state_size, input_features + state_size)) + self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.state_size) + for weight in self.parameters(): + weight.data.uniform_(-stdv, +stdv) + + def forward(self, input, state): + return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/tests/test_cuda_extension/lltm_cuda.cpp b/monai/networks/extensions/lltm/lltm_cuda.cpp similarity index 78% rename from tests/test_cuda_extension/lltm_cuda.cpp rename to monai/networks/extensions/lltm/lltm_cuda.cpp index 2434776abd..b5db580d8c 100644 --- a/tests/test_cuda_extension/lltm_cuda.cpp +++ b/monai/networks/extensions/lltm/lltm_cuda.cpp @@ -1,3 +1,16 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + #include #include diff --git a/tests/test_cuda_extension/lltm_cuda_kernel.cu b/monai/networks/extensions/lltm/lltm_cuda_kernel.cu similarity index 92% rename from tests/test_cuda_extension/lltm_cuda_kernel.cu rename to monai/networks/extensions/lltm/lltm_cuda_kernel.cu index 02bb9ade5d..dd9aeeb024 100644 --- a/tests/test_cuda_extension/lltm_cuda_kernel.cu +++ b/monai/networks/extensions/lltm/lltm_cuda_kernel.cu @@ -1,3 +1,16 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + #include #include diff --git a/setup_cpp.py b/setup_cpp.py new file mode 100644 index 0000000000..c7e91faafb --- /dev/null +++ b/setup_cpp.py @@ -0,0 +1,23 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CppExtension + +setup( + name="lltm_cpp", + ext_modules=[ + CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"]), + ], + cmdclass={ + "build_ext": BuildExtension + }) diff --git a/setup_cuda.py b/setup_cuda.py new file mode 100644 index 0000000000..611d38debd --- /dev/null +++ b/setup_cuda.py @@ -0,0 +1,25 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="lltm_cuda", + ext_modules=[ + CUDAExtension("lltm_cuda", [ + "monai/networks/extensions/lltm/lltm_cuda.cpp", + "monai/networks/extensions/lltm/lltm_cuda_kernel.cu", + ]), + ], + cmdclass={ + "build_ext": BuildExtension + }) diff --git a/tests/test_cuda_extension/__init__.py b/tests/test_cuda_extension/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/test_cuda_extension/jit.py b/tests/test_cuda_extension/jit.py deleted file mode 100644 index 6c52efff8f..0000000000 --- a/tests/test_cuda_extension/jit.py +++ /dev/null @@ -1,4 +0,0 @@ -from torch.utils.cpp_extension import load -lltm_cuda = load( - 'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True) -help(lltm_cuda) diff --git a/tests/test_cuda_extension/lltm.py b/tests/test_cuda_extension/lltm.py deleted file mode 100644 index c740b8864f..0000000000 --- a/tests/test_cuda_extension/lltm.py +++ /dev/null @@ -1,45 +0,0 @@ -import math -from torch import nn -from torch.autograd import Function -import torch - -import lltm_cuda - -torch.manual_seed(42) - - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - outputs = lltm_cuda.backward( - grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables) - d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - -class LLTM(nn.Module): - def __init__(self, input_features, state_size): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - self.weights = nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/tests/test_cuda_extension/setup.py b/tests/test_cuda_extension/setup.py deleted file mode 100644 index 670b3c86e6..0000000000 --- a/tests/test_cuda_extension/setup.py +++ /dev/null @@ -1,14 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name='lltm_cuda', - ext_modules=[ - CUDAExtension('lltm_cuda', [ - 'lltm_cuda.cpp', - 'lltm_cuda_kernel.cu', - ]), - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/tests/test_lltm.py b/tests/test_lltm.py new file mode 100644 index 0000000000..b4af394e1e --- /dev/null +++ b/tests/test_lltm.py @@ -0,0 +1,55 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.extensions import LLTM + +TEST_CASE_1 = [ + {"input_features": 32, "state_size": 2}, + torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]), + torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]), +] + + +class TestLLTM(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_value(self, input_param, expected_h, expected_C): + torch.manual_seed(0) + X = torch.randn(4, 32) + h = torch.randn(4, 2) + C = torch.randn(4, 2) + new_h, new_C = LLTM(**input_param)(X, (h, C)) + (new_h.sum() + new_C.sum()).backward() + + torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_C, expected_C, rtol=0.0001, atol=1e-04) + + @parameterized.expand([TEST_CASE_1]) + def test_value_cuda(self, input_param, expected_h, expected_C): + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") + torch.manual_seed(0) + X = torch.randn(4, 32).to(device) + h = torch.randn(4, 2).to(device) + C = torch.randn(4, 2).to(device) + lltm = LLTM(**input_param).to(device) + new_h, new_C = lltm(X, (h, C)) + (new_h.sum() + new_C.sum()).backward() + + torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_C, expected_C.to(device), rtol=0.0001, atol=1e-04) + + +if __name__ == "__main__": + unittest.main() From 3a7f5601736a9718811fc485a4af8f2428f82b37 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 27 Jul 2020 23:50:54 +0000 Subject: [PATCH 03/22] [MONAI] python code formatting --- monai/networks/extensions/lltm/lltm.py | 4 ++-- setup_cpp.py | 9 +++------ setup_cuda.py | 13 ++++++------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py index ea724c0248..c70ec47363 100644 --- a/monai/networks/extensions/lltm/lltm.py +++ b/monai/networks/extensions/lltm/lltm.py @@ -15,6 +15,7 @@ import torch from monai.utils import optional_import + lltm_cpp, _ = optional_import("lltm_cpp") lltm_cuda, _ = optional_import("lltm_cuda") @@ -47,8 +48,7 @@ def __init__(self, input_features, state_size): super(LLTM, self).__init__() self.input_features = input_features self.state_size = state_size - self.weights = nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) + self.weights = nn.Parameter(torch.Tensor(3 * state_size, input_features + state_size)) self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) self.reset_parameters() diff --git a/setup_cpp.py b/setup_cpp.py index c7e91faafb..dd08ae8d49 100644 --- a/setup_cpp.py +++ b/setup_cpp.py @@ -15,9 +15,6 @@ setup( name="lltm_cpp", - ext_modules=[ - CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"]), - ], - cmdclass={ - "build_ext": BuildExtension - }) + ext_modules=[CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"]),], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/setup_cuda.py b/setup_cuda.py index 611d38debd..706f30f5a5 100644 --- a/setup_cuda.py +++ b/setup_cuda.py @@ -15,11 +15,10 @@ setup( name="lltm_cuda", ext_modules=[ - CUDAExtension("lltm_cuda", [ - "monai/networks/extensions/lltm/lltm_cuda.cpp", - "monai/networks/extensions/lltm/lltm_cuda_kernel.cu", - ]), + CUDAExtension( + "lltm_cuda", + 
["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu",], + ), ], - cmdclass={ - "build_ext": BuildExtension - }) + cmdclass={"build_ext": BuildExtension}, +) From 868c9efafc0400a5fbaf82b3794870360df4327b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 08:31:51 +0800 Subject: [PATCH 04/22] [DLMED] update flake8 --- setup_cpp.py | 2 +- setup_cuda.py | 2 +- tests/test_lltm.py | 24 ++++++++++++------------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/setup_cpp.py b/setup_cpp.py index dd08ae8d49..d4cb3f539c 100644 --- a/setup_cpp.py +++ b/setup_cpp.py @@ -15,6 +15,6 @@ setup( name="lltm_cpp", - ext_modules=[CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"]),], + ext_modules=[CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"])], cmdclass={"build_ext": BuildExtension}, ) diff --git a/setup_cuda.py b/setup_cuda.py index 706f30f5a5..83f7f09216 100644 --- a/setup_cuda.py +++ b/setup_cuda.py @@ -17,7 +17,7 @@ ext_modules=[ CUDAExtension( "lltm_cuda", - ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu",], + ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu"], ), ], cmdclass={"build_ext": BuildExtension}, diff --git a/tests/test_lltm.py b/tests/test_lltm.py index b4af394e1e..9e65db77fe 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -25,30 +25,30 @@ class TestLLTM(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) - def test_value(self, input_param, expected_h, expected_C): + def test_value(self, input_param, expected_h, expected_c): torch.manual_seed(0) - X = torch.randn(4, 32) + x = torch.randn(4, 32) h = torch.randn(4, 2) - C = torch.randn(4, 2) - new_h, new_C = LLTM(**input_param)(X, (h, C)) - (new_h.sum() + new_C.sum()).backward() + c = torch.randn(4, 2) + new_h, new_c = LLTM(**input_param)(x, (h, c)) + (new_h.sum() + new_c.sum()).backward() torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04) - torch.testing.assert_allclose(new_C, expected_C, rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04) @parameterized.expand([TEST_CASE_1]) - def test_value_cuda(self, input_param, expected_h, expected_C): + def test_value_cuda(self, input_param, expected_h, expected_c): device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") torch.manual_seed(0) - X = torch.randn(4, 32).to(device) + x = torch.randn(4, 32).to(device) h = torch.randn(4, 2).to(device) - C = torch.randn(4, 2).to(device) + c = torch.randn(4, 2).to(device) lltm = LLTM(**input_param).to(device) - new_h, new_C = lltm(X, (h, C)) - (new_h.sum() + new_C.sum()).backward() + new_h, new_c = lltm(x, (h, c)) + (new_h.sum() + new_c.sum()).backward() torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04) - torch.testing.assert_allclose(new_C, expected_C.to(device), rtol=0.0001, atol=1e-04) + torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04) if __name__ == "__main__": From 48fcb77fc7aac161396db911fdaa1b555bb90991 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 11:46:00 +0800 Subject: [PATCH 05/22] [DLMED] add doc-strings --- monai/networks/extensions/lltm/lltm.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py index 
c70ec47363..fb62eaeb03 100644 --- a/monai/networks/extensions/lltm/lltm.py +++ b/monai/networks/extensions/lltm/lltm.py @@ -44,7 +44,20 @@ def backward(ctx, grad_h, grad_cell): class LLTM(nn.Module): - def __init__(self, input_features, state_size): + """ + This recurrent unit is similar to an LSTM, but differs in that it lacks a forget + gate and uses an Exponential Linear Unit (ELU) as its internal activation function. + Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit. + It has both C++ and CUDA implementation, automatically switch according to the + target device where put this module to. + + Args: + input_features: size of input feature data + state_size: size of the state of recurrent unit + + Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html + """ + def __init__(self, input_features: int, state_size: int): super(LLTM, self).__init__() self.input_features = input_features self.state_size = state_size From 0515c9f6bfa0086c725243630737dc15639480cf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 11:49:11 +0800 Subject: [PATCH 06/22] [DLMED] add to docs --- docs/source/networks.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d7f1c4e3dd..26b4e2fc1c 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -66,6 +66,17 @@ Blocks :members: +Extensions +---------- +.. automodule:: monai.networks.extensions +.. currentmodule:: monai.networks.extensions + +`LLTM` +~~~~~~~~~~~~~ +.. autoclass:: LLTM + :members: + + Layers ------ From d8070394b2964f5c77620425612cb9c181ccf23b Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 28 Jul 2020 06:09:29 +0000 Subject: [PATCH 07/22] [MONAI] python code formatting --- monai/networks/extensions/lltm/lltm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py index fb62eaeb03..cdf67e136e 100644 --- a/monai/networks/extensions/lltm/lltm.py +++ b/monai/networks/extensions/lltm/lltm.py @@ -57,6 +57,7 @@ class LLTM(nn.Module): Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html """ + def __init__(self, input_features: int, state_size: int): super(LLTM, self).__init__() self.input_features = input_features From c4bad5b9811aa5d0fd845033fda8f68ff04b9eb8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 14:52:18 +0800 Subject: [PATCH 08/22] [DLMED] fix mypy error --- monai/networks/extensions/lltm/lltm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py index cdf67e136e..cbb8699182 100644 --- a/monai/networks/extensions/lltm/lltm.py +++ b/monai/networks/extensions/lltm/lltm.py @@ -62,8 +62,8 @@ def __init__(self, input_features: int, state_size: int): super(LLTM, self).__init__() self.input_features = input_features self.state_size = state_size - self.weights = nn.Parameter(torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) + self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size)) + self.bias = nn.Parameter(torch.empty(1, 3 * state_size)) self.reset_parameters() def reset_parameters(self): From 311e5f234c968cc41242aa575981fbcc383263e9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 22:01:16 +0800 Subject: [PATCH 09/22] [DLMED] update according to comments --- .github/workflows/pythonapp.yml | 5 +- 
.github/workflows/setupapp.yml | 3 + docs/source/networks.rst | 2 +- monai/networks/extensions/__init__.py | 12 --- monai/networks/extensions/lltm/__init__.py | 10 --- monai/networks/extensions/lltm/lltm.cpp | 4 +- monai/networks/extensions/lltm/lltm.py | 75 ------------------- monai/networks/extensions/lltm/lltm_cuda.cpp | 4 +- monai/networks/layers/simplelayers.py | 67 ++++++++++++++++- setup.py | 26 ++++++- setup_cpp.py | 20 ----- setup_cuda.py | 24 ------ tests/testing_data/._Task04_Hippocampus | Bin 120 -> 0 bytes 13 files changed, 97 insertions(+), 155 deletions(-) delete mode 100644 monai/networks/extensions/__init__.py delete mode 100644 monai/networks/extensions/lltm/__init__.py delete mode 100644 monai/networks/extensions/lltm/lltm.py delete mode 100644 setup_cpp.py delete mode 100644 setup_cuda.py delete mode 100755 tests/testing_data/._Task04_Hippocampus diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index eb4338be57..79a2fa7267 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -63,7 +63,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python setup_cpp.py install + python setup.py install - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -85,8 +85,7 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python setup_cpp.py install - python setup_cuda.py install + python setup.py install - name: Run quick tests (GPU) run: | python -m pip list diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index fed6de46a7..3050ee0e32 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -22,6 +22,7 @@ jobs: python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt python -m pip list + python setup.py install - name: Run unit tests report coverage run: | nvidia-smi @@ -55,6 +56,7 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt + python setup.py install - name: Run quick tests CPU ubuntu run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -95,6 +97,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list + python setup.py install - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 26b4e2fc1c..32c7ee0caf 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -72,7 +72,7 @@ Extensions .. currentmodule:: monai.networks.extensions `LLTM` -~~~~~~~~~~~~~ +~~~~~~ .. autoclass:: LLTM :members: diff --git a/monai/networks/extensions/__init__.py b/monai/networks/extensions/__init__.py deleted file mode 100644 index 899ced6864..0000000000 --- a/monai/networks/extensions/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .lltm.lltm import LLTM diff --git a/monai/networks/extensions/lltm/__init__.py b/monai/networks/extensions/lltm/__init__.py deleted file mode 100644 index d0044e3563..0000000000 --- a/monai/networks/extensions/lltm/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/networks/extensions/lltm/lltm.cpp b/monai/networks/extensions/lltm/lltm.cpp index 5d098b461c..c8e9077838 100644 --- a/monai/networks/extensions/lltm/lltm.cpp +++ b/monai/networks/extensions/lltm/lltm.cpp @@ -98,6 +98,6 @@ std::vector lltm_backward( } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward"); - m.def("backward", &lltm_backward, "LLTM backward"); + m.def("lltm_forward", &lltm_forward, "LLTM forward"); + m.def("lltm_backward", &lltm_backward, "LLTM backward"); } diff --git a/monai/networks/extensions/lltm/lltm.py b/monai/networks/extensions/lltm/lltm.py deleted file mode 100644 index cbb8699182..0000000000 --- a/monai/networks/extensions/lltm/lltm.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -from torch import nn -from torch.autograd import Function -import torch - -from monai.utils import optional_import - -lltm_cpp, _ = optional_import("lltm_cpp") -lltm_cuda, _ = optional_import("lltm_cuda") - - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - ext = lltm_cuda if weights.is_cuda else lltm_cpp - outputs = ext.forward(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - if grad_h.is_cuda: - outputs = lltm_cuda.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) - d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs - else: - outputs = lltm_cpp.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) - d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs - - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - -class LLTM(nn.Module): - """ - This recurrent unit is similar to an LSTM, but differs in that it lacks a forget - gate and uses an Exponential Linear Unit (ELU) as its internal activation function. - Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit. - It has both C++ and CUDA implementation, automatically switch according to the - target device where put this module to. - - Args: - input_features: size of input feature data - state_size: size of the state of recurrent unit - - Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html - """ - - def __init__(self, input_features: int, state_size: int): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.empty(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/monai/networks/extensions/lltm/lltm_cuda.cpp b/monai/networks/extensions/lltm/lltm_cuda.cpp index b5db580d8c..ed27c25835 100644 --- a/monai/networks/extensions/lltm/lltm_cuda.cpp +++ b/monai/networks/extensions/lltm/lltm_cuda.cpp @@ -89,6 +89,6 @@ std::vector lltm_backward( } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward (CUDA)"); - m.def("backward", &lltm_backward, "LLTM backward (CUDA)"); + m.def("lltm_forward", &lltm_forward, "LLTM forward (CUDA)"); + m.def("lltm_backward", &lltm_backward, "LLTM backward (CUDA)"); } diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index b1ed11eeff..5338f22655 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -10,15 +10,19 @@ # limitations under the License. 
from typing import Sequence, Union - +import math import torch -import torch.nn as nn +from torch import nn +from torch.autograd import Function import torch.nn.functional as F - from monai.networks.layers.convutils import gaussian_1d, same_padding from monai.utils import ensure_tuple_rep +from monai.utils import optional_import -__all__ = ["SkipConnection", "Flatten", "GaussianFilter"] +_C, _ = optional_import("monai._C") +_C_CUDA, _ = optional_import("monai._C_CUDA") + +__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM"] class SkipConnection(nn.Module): @@ -113,3 +117,58 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: return self.conv_n(input=_conv(input_, d - 1), weight=kernel, padding=padding, groups=chns) return _conv(x, sp_dim - 1) + + +class LLTMFunction(Function): + @staticmethod + def forward(ctx, input, weights, bias, old_h, old_cell): + ext = _C_CUDA if weights.is_cuda else _C + outputs = ext.lltm_forward(input, weights, bias, old_h, old_cell) + new_h, new_cell = outputs[:2] + variables = outputs[1:] + [weights] + ctx.save_for_backward(*variables) + + return new_h, new_cell + + @staticmethod + def backward(ctx, grad_h, grad_cell): + if grad_h.is_cuda: + outputs = _C_CUDA.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) + d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs + else: + outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) + d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs + + return d_input, d_weights, d_bias, d_old_h, d_old_cell + + +class LLTM(nn.Module): + """ + This recurrent unit is similar to an LSTM, but differs in that it lacks a forget + gate and uses an Exponential Linear Unit (ELU) as its internal activation function. + Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit. + It has both C++ and CUDA implementation, automatically switch according to the + target device where put this module to. + + Args: + input_features: size of input feature data + state_size: size of the state of recurrent unit + + Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html + """ + + def __init__(self, input_features: int, state_size: int): + super(LLTM, self).__init__() + self.input_features = input_features + self.state_size = state_size + self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size)) + self.bias = nn.Parameter(torch.empty(1, 3 * state_size)) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.state_size) + for weight in self.parameters(): + weight.data.uniform_(-stdv, +stdv) + + def forward(self, input, state): + return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/setup.py b/setup.py index 5158fa1fb9..baa35e6599 100644 --- a/setup.py +++ b/setup.py @@ -10,14 +10,36 @@ # limitations under the License. 
from setuptools import find_packages, setup - +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension import versioneer if __name__ == "__main__": + cmds = versioneer.get_cmdclass() + cmds.update({"build_ext": BuildExtension}) + ext_modules = [ + CppExtension( + "monai._C", + [ + "monai/networks/extensions/lltm/lltm.cpp" + ] + ) + ] + if torch.cuda.is_available(): + ext_modules.append( + CUDAExtension( + "monai._C_CUDA", + [ + "monai/networks/extensions/lltm/lltm_cuda.cpp", + "monai/networks/extensions/lltm/lltm_cuda_kernel.cu" + ], + ) + ) setup( version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), + cmdclass=cmds, packages=find_packages(exclude=("docs", "examples", "tests", "research")), zip_safe=False, package_data={"monai": ["py.typed"]}, + ext_modules=ext_modules, ) diff --git a/setup_cpp.py b/setup_cpp.py deleted file mode 100644 index d4cb3f539c..0000000000 --- a/setup_cpp.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CppExtension - -setup( - name="lltm_cpp", - ext_modules=[CppExtension("lltm_cpp", ["monai/networks/extensions/lltm/lltm.cpp"])], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/setup_cuda.py b/setup_cuda.py deleted file mode 100644 index 83f7f09216..0000000000 --- a/setup_cuda.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name="lltm_cuda", - ext_modules=[ - CUDAExtension( - "lltm_cuda", - ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu"], - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/tests/testing_data/._Task04_Hippocampus b/tests/testing_data/._Task04_Hippocampus deleted file mode 100755 index 7e298b58a4ae62237f53b2df6f7177530d27dd9d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 120 zcmZQz6=P>$Vqox1Ojhs@R)|o50+1L3ClDI}u^SMB_!U452P|+FI>Rv}BnT=7(t)B2 GrWOGICI*}U From 86bcdfaeb5ea9dad476660b59a1fee9dfc3593f8 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 28 Jul 2020 14:07:31 +0000 Subject: [PATCH 10/22] [MONAI] python code formatting --- setup.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index baa35e6599..09915d8f8f 100644 --- a/setup.py +++ b/setup.py @@ -17,22 +17,12 @@ if __name__ == "__main__": cmds = versioneer.get_cmdclass() cmds.update({"build_ext": BuildExtension}) - ext_modules = [ - CppExtension( - "monai._C", - [ - "monai/networks/extensions/lltm/lltm.cpp" - ] - ) - ] + ext_modules = [CppExtension("monai._C", ["monai/networks/extensions/lltm/lltm.cpp"])] if torch.cuda.is_available(): ext_modules.append( CUDAExtension( "monai._C_CUDA", - [ - "monai/networks/extensions/lltm/lltm_cuda.cpp", - "monai/networks/extensions/lltm/lltm_cuda_kernel.cu" - ], + ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu"], ) ) setup( From 135e3003d3884b3799ea88701893f8beee3ecbbf Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 22:14:17 +0800 Subject: [PATCH 11/22] [DLMED] fix package --- .github/workflows/pythonapp.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 79a2fa7267..8e83bb54c4 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -112,9 +112,10 @@ jobs: uses: actions/setup-python@v1 with: python-version: '3.x' - - name: Install setuptools + - name: Install dependencies run: | python -m pip install --user --upgrade pip setuptools wheel twine + python -m pip install torch==1.4 - name: Test source archive and wheel file run: | git fetch --depth=1 origin +refs/tags/*:refs/tags/* From 29f066712e1de2145809cebf670b4c8e4084f091 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Jul 2020 22:40:17 +0800 Subject: [PATCH 12/22] [DLMED] fix test errors --- .github/workflows/pythonapp.yml | 1 + docs/source/networks.rst | 16 +++++----------- tests/test_lltm.py | 2 +- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 8e83bb54c4..566b8061c9 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -116,6 +116,7 @@ jobs: run: | python -m pip install --user --upgrade pip setuptools wheel twine python -m pip install torch==1.4 + python -m pip install -r requirements.txt - name: Test source archive and wheel file run: | git fetch --depth=1 origin +refs/tags/*:refs/tags/* diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 32c7ee0caf..7fa23a4f64 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -66,17 +66,6 @@ Blocks :members: -Extensions ----------- -.. automodule:: monai.networks.extensions -.. 
currentmodule:: monai.networks.extensions - -`LLTM` -~~~~~~ -.. autoclass:: LLTM - :members: - - Layers ------ @@ -140,6 +129,11 @@ Layers .. autoclass:: monai.networks.layers.AffineTransform :members: +`LLTM` +~~~~~~ +.. autoclass:: LLTM + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils diff --git a/tests/test_lltm.py b/tests/test_lltm.py index 9e65db77fe..5c666a8794 100644 --- a/tests/test_lltm.py +++ b/tests/test_lltm.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.networks.extensions import LLTM +from monai.networks.layers import LLTM TEST_CASE_1 = [ {"input_features": 32, "state_size": 2}, From e3347dcd70eb6f2a4c7e3f5e8748813c6b58566f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 28 Jul 2020 18:00:46 +0100 Subject: [PATCH 13/22] fixes install errors --- .github/workflows/pythonapp.yml | 6 +++--- .github/workflows/setupapp.yml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 566b8061c9..1d5eaa02c9 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -63,7 +63,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python setup.py install + python -m pip install -e . --no-use-pep517 - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -85,7 +85,7 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python setup.py install + python -m pip install -e . --no-use-pep517 - name: Run quick tests (GPU) run: | python -m pip list @@ -144,7 +144,7 @@ jobs: rm monai*.whl # install from tar.gz - python -m pip install monai*.tar.gz + python -m pip install monai*.tar.gz --no-use-pep517 python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv "unknown" python -c 'import monai; print(monai.__file__)' python -m pip uninstall -y monai diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 3050ee0e32..1ed76c8286 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -22,7 +22,7 @@ jobs: python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt python -m pip list - python setup.py install + python -m pip install -e . --no-use-pep517 - name: Run unit tests report coverage run: | nvidia-smi @@ -56,7 +56,7 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python setup.py install + python -m pip install -e . --no-use-pep517 - name: Run quick tests CPU ubuntu run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -97,7 +97,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python setup.py install + python -m pip install -e . 
--no-use-pep517 - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' From 54d615c0482795484c8e0bfd0320738d9487d2bd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 29 Jul 2020 15:42:29 +0100 Subject: [PATCH 14/22] test pep517 install --- .github/workflows/pythonapp.yml | 8 ++++---- .github/workflows/setupapp.yml | 6 +++--- pyproject.toml | 8 ++++++++ setup.cfg | 3 +++ 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 1d5eaa02c9..5be6ade32b 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -63,7 +63,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python -m pip install -e . --no-use-pep517 + python setup.py develop - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -85,7 +85,7 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python -m pip install -e . --no-use-pep517 + python setup.py develop - name: Run quick tests (GPU) run: | python -m pip list @@ -115,7 +115,7 @@ jobs: - name: Install dependencies run: | python -m pip install --user --upgrade pip setuptools wheel twine - python -m pip install torch==1.4 + python -m pip install torch>=1.4 python -m pip install -r requirements.txt - name: Test source archive and wheel file run: | @@ -144,7 +144,7 @@ jobs: rm monai*.whl # install from tar.gz - python -m pip install monai*.tar.gz --no-use-pep517 + python -m pip install monai*.tar.gz python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv "unknown" python -c 'import monai; print(monai.__file__)' python -m pip uninstall -y monai diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 1ed76c8286..effc65c9b6 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -22,7 +22,7 @@ jobs: python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt python -m pip list - python -m pip install -e . --no-use-pep517 + python setup.py develop - name: Run unit tests report coverage run: | nvidia-smi @@ -56,7 +56,7 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python -m pip install -e . --no-use-pep517 + python setup.py develop - name: Run quick tests CPU ubuntu run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -97,7 +97,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python -m pip install -e . 
--no-use-pep517 + python setup.py develop - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' diff --git a/pyproject.toml b/pyproject.toml index 36ececc145..9c49e77242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,11 @@ +[build-system] +requires = [ + "wheel", + "setuptools", + "torch>=1.4", + "ninja", +] + [tool.black] line-length = 120 target-version = ['py36', 'py37', 'py38'] diff --git a/setup.cfg b/setup.cfg index dee943d076..0bd4133310 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,6 +11,9 @@ license = Apache License 2.0 [options] python_requires = >= 3.6 +setup_requires = + torch>=1.4 + ninja install_requires = torch>=1.4 numpy>=1.17 From 6328ad8a97e15ef46bedda33e73f16e07156a632 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 09:09:35 +0100 Subject: [PATCH 15/22] full tests --- .github/workflows/cron.yml | 2 ++ .github/workflows/pythonapp.yml | 1 - .github/workflows/setupapp.yml | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 2c7b1ba971..7d8e0daf0f 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -27,6 +27,7 @@ jobs: fi python -m pip install -r requirements-dev.txt python -m pip list + python setup.py develop - name: Run tests report coverage run: | nvidia-smi @@ -57,6 +58,7 @@ jobs: python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' ngc --version + python setup.py develop ./runtests.sh --coverage --pytype coverage xml - name: Upload coverage diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 5be6ade32b..c62eb7fbdf 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -116,7 +116,6 @@ jobs: run: | python -m pip install --user --upgrade pip setuptools wheel twine python -m pip install torch>=1.4 - python -m pip install -r requirements.txt - name: Test source archive and wheel file run: | git fetch --depth=1 origin +refs/tags/*:refs/tags/* diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index effc65c9b6..c53204c870 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -5,6 +5,7 @@ on: push: branches: - master + pull_request: jobs: coverage-py3: From 8d3d9381ad84e53ad26d69b62d55aceac416950d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 12:02:39 +0100 Subject: [PATCH 16/22] adds local testing build commands --- runtests.sh | 60 ++++++++++++++++++++++++----------------------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/runtests.sh b/runtests.sh index d7657d0ee1..c9edce31e3 100755 --- a/runtests.sh +++ b/runtests.sh @@ -45,10 +45,11 @@ function print_usage { echo "MONAI unit testing utilities." echo "" echo "Examples:" - echo "./runtests.sh --codeformat --coverage # runs full tests (${green}recommended before making pull requests${noColor})." - echo "./runtests.sh --codeformat --nounittests # runs coding style and static type checking." - echo "./runtests.sh --quick # runs minimal unit tests, for quick verification during code developments." - echo "./runtests.sh --black-fix # runs automatic code formatting using \"black\"." + echo "./runtests.sh --codeformat --coverage # run full tests (${green}recommended before making pull requests${noColor})." 
+ echo "./runtests.sh --codeformat --nounittests # run coding style and static type checking." + echo "./runtests.sh --quick # run minimal unit tests, for quick verification during code developments." + echo "./runtests.sh --black-fix # run automatic code formatting using \"black\"." + echo "./runtests.sh --clean # clean up temporary files and run \"python setup.py develop --uninstall\"." echo "" echo "Code style check options:" echo " --black : perform \"black\" code format checks" @@ -85,11 +86,26 @@ function print_version { ${cmdPrefix}python -c 'import monai; monai.config.print_config()' } -function install_deps { +function install_compile_deps { echo "Pip installing MONAI development dependencies..." + ${cmdPrefix}python setup.py -v develop ${cmdPrefix}pip install -r requirements-dev.txt } +function clean_py() { + # uninstall the development package + ${cmdPrefix}python setup.py -v develop --uninstall + + # remove temporary files + TO_CLEAN=${*:-'.'} + find ${TO_CLEAN} -type f -name "*.py[co]" -delete + find ${TO_CLEAN} -type f -name ".coverage" -delete + find ${TO_CLEAN} -type d -name "__pycache__" -delete + find ${TO_CLEAN} -depth -type d -name ".mypy_cache" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -type d -name ".pytype" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -type d -name ".coverage" -exec rm -r "{}" + +} + function torch_validate { ${cmdPrefix}python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' } @@ -205,29 +221,7 @@ if [ $doCleanup = true ] then echo "${separator}${blue}clean${noColor}" - if [ -d .mypy_cache ] - then - ${cmdPrefix}rm -r .mypy_cache - elif [ -f .mypy_cache ] - then - ${cmdPrefix}rm .mypy_cache - fi - - if [ -d .pytype ] - then - ${cmdPrefix}rm -r .pytype - elif [ -f .pytype ] - then - ${cmdPrefix}rm .pytype - fi - - if [ -d .coverage ] - then - ${cmdPrefix}rm -r .coverage - elif [ -f .coverage ] - then - ${cmdPrefix}rm .coverage - fi + clean_py echo "${green}done!${noColor}" exit @@ -247,7 +241,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which black)" ]] then - install_deps + install_compile_deps fi ${cmdPrefix}black --version @@ -283,7 +277,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which isort)" ]] then - install_deps + install_compile_deps fi ${cmdPrefix}isort --version @@ -314,7 +308,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which flake8)" ]] then - install_deps + install_compile_deps fi ${cmdPrefix}flake8 --version @@ -340,7 +334,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which pytype)" ]] then - install_deps + install_compile_deps fi ${cmdPrefix}pytype --version @@ -366,7 +360,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! 
-f "$(which mypy)" ]] then - install_deps + install_compile_deps fi ${cmdPrefix}mypy --version From 20ee5bc89b3b87ebf2a24a28f030829bf3f7411b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 12:29:34 +0100 Subject: [PATCH 17/22] cleanup setup.py --- requirements-dev.txt | 1 + setup.py | 50 +++++++++++++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 50dcde1ee3..36be129174 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -22,3 +22,4 @@ black isort pytype>=2020.6.1 mypy +ninja diff --git a/setup.py b/setup.py index 09915d8f8f..7c1f2a03ee 100644 --- a/setup.py +++ b/setup.py @@ -9,27 +9,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + from setuptools import find_packages, setup -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension + import versioneer -if __name__ == "__main__": - cmds = versioneer.get_cmdclass() - cmds.update({"build_ext": BuildExtension}) + +def get_extensions(): + + try: + import torch + from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + except ImportError: + warnings.warn("torch cpp/cuda building skipped.") + return [] + ext_modules = [CppExtension("monai._C", ["monai/networks/extensions/lltm/lltm.cpp"])] - if torch.cuda.is_available(): + if torch.cuda.is_available() and (CUDA_HOME is not None): ext_modules.append( CUDAExtension( "monai._C_CUDA", ["monai/networks/extensions/lltm/lltm_cuda.cpp", "monai/networks/extensions/lltm/lltm_cuda_kernel.cu"], ) ) - setup( - version=versioneer.get_version(), - cmdclass=cmds, - packages=find_packages(exclude=("docs", "examples", "tests", "research")), - zip_safe=False, - package_data={"monai": ["py.typed"]}, - ext_modules=ext_modules, - ) + return ext_modules + + +def get_cmds(): + cmds = versioneer.get_cmdclass() + try: + from torch.utils.cpp_extension import BuildExtension + + cmds.update({"build_ext": BuildExtension}) + except ImportError: + warnings.warn("torch cpp_extension module not found.") + return cmds + + +setup( + version=versioneer.get_version(), + cmdclass=get_cmds(), + packages=find_packages(exclude=("docs", "examples", "tests", "research")), + zip_safe=False, + package_data={"monai": ["py.typed"]}, + ext_modules=get_extensions(), +) From cd560d6cb24f7ed3e1e29cac2548f647e732eec5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 13:27:42 +0100 Subject: [PATCH 18/22] cleanup ci scripts --- .github/workflows/cron.yml | 2 -- .github/workflows/pythonapp.yml | 3 +-- .github/workflows/setupapp.yml | 2 -- monai/__init__.py | 4 +++- runtests.sh | 33 ++++++++++++++++++++++++--------- setup.cfg | 10 +++++++--- 6 files changed, 35 insertions(+), 19 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 7d8e0daf0f..2c7b1ba971 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -27,7 +27,6 @@ jobs: fi python -m pip install -r requirements-dev.txt python -m pip list - python setup.py develop - name: Run tests report coverage run: | nvidia-smi @@ -58,7 +57,6 @@ jobs: python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' ngc --version - python setup.py develop ./runtests.sh --coverage --pytype coverage xml - name: Upload coverage diff --git 
a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index c62eb7fbdf..14d55fcfa9 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -63,7 +63,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python setup.py develop + python setup.py develop # compile the cpp extensions - name: Run quick tests (CPU ${{ runner.os }}) run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' @@ -85,7 +85,6 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python setup.py develop - name: Run quick tests (GPU) run: | python -m pip list diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index c53204c870..916c9220dc 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -23,7 +23,6 @@ jobs: python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt python -m pip list - python setup.py develop - name: Run unit tests report coverage run: | nvidia-smi @@ -57,7 +56,6 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python setup.py develop - name: Run quick tests CPU ubuntu run: | python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' diff --git a/monai/__init__.py b/monai/__init__.py index a179c73532..f244d07037 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -22,7 +22,9 @@ __basedir__ = os.path.dirname(__file__) -excludes = "^(handlers)" # the handlers have some external decorators the users may not have installed +# handlers_* have some external decorators the users may not have installed +# *.so files and folder "_C" may not exists when the cpp extensions are not compiled +excludes = "(^(handlers))|((\\.so)$)|(_C)" # load directory modules only, skip loading individual files load_submodules(sys.modules[__name__], False, exclude_pattern=excludes) diff --git a/runtests.sh b/runtests.sh index c9edce31e3..e4926a8f72 100755 --- a/runtests.sh +++ b/runtests.sh @@ -86,12 +86,22 @@ function print_version { ${cmdPrefix}python -c 'import monai; monai.config.print_config()' } -function install_compile_deps { - echo "Pip installing MONAI development dependencies..." - ${cmdPrefix}python setup.py -v develop +function install_deps { + echo "Pip installing MONAI development dependencies and compile MONAI cpp extensions..." ${cmdPrefix}pip install -r requirements-dev.txt } +function compile_cpp { + echo "Compiling and installing MONAI cpp extensions..." 
+ ${cmdPrefix}python setup.py -v develop --uninstall + if [[ "$OSTYPE" == "darwin"* ]]; + then # clang for mac os + CC=clang CXX=clang++ ${cmdPrefix}python setup.py -v develop + else + ${cmdPrefix}python setup.py -v develop + fi +} + function clean_py() { # uninstall the development package ${cmdPrefix}python setup.py -v develop --uninstall @@ -101,6 +111,10 @@ function clean_py() { find ${TO_CLEAN} -type f -name "*.py[co]" -delete find ${TO_CLEAN} -type f -name ".coverage" -delete find ${TO_CLEAN} -type d -name "__pycache__" -delete + + find ${TO_CLEAN} -depth -type d -name ".eggs" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -type d -name "monai.egg-info" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -type d -name "build" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".mypy_cache" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".pytype" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".coverage" -exec rm -r "{}" + @@ -216,7 +230,6 @@ fi # unconditionally report on the state of monai print_version - if [ $doCleanup = true ] then echo "${separator}${blue}clean${noColor}" @@ -227,6 +240,8 @@ then exit fi +# try to compile MONAI cpp +compile_cpp if [ $doBlackFormat = true ] then @@ -241,7 +256,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which black)" ]] then - install_compile_deps + install_deps fi ${cmdPrefix}black --version @@ -277,7 +292,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which isort)" ]] then - install_compile_deps + install_deps fi ${cmdPrefix}isort --version @@ -308,7 +323,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which flake8)" ]] then - install_compile_deps + install_deps fi ${cmdPrefix}flake8 --version @@ -334,7 +349,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which pytype)" ]] then - install_compile_deps + install_deps fi ${cmdPrefix}pytype --version @@ -360,7 +375,7 @@ then # ensure that the necessary packages for code format testing are installed if [[ ! -f "$(which mypy)" ]] then - install_compile_deps + install_deps fi ${cmdPrefix}mypy --version diff --git a/setup.cfg b/setup.cfg index 0bd4133310..b724cccb38 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,13 +51,13 @@ ignore = # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' N812 per-file-ignores = __init__.py: F401 -exclude = *.pyi,.git,monai/_version.py,versioneer.py,venv, _version.py +exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv, _version.py [isort] known_first_party = monai profile = black line_length = 120 -skip = .git, venv, versioneer.py, _version.py, conf.py +skip = .git, .eggs, venv, versioneer.py, _version.py, conf.py skip_glob = *.pyi [versioneer] @@ -111,11 +111,15 @@ ignore_errors = True # Always ignore any type issues in the monai/._version.py file ignore_errors = True +[mypy-monai.eggs] +# Always ignore any type issues in the monai/.eggs file +ignore_errors = True + [pytype] # NOTE: All relative paths are relative to the location of this file. # Space-separated list of files or directories to exclude. # i.e. exclude = **/*_test.py **/test_*.py -exclude = **/versioneer.py _version.py +exclude = **/versioneer.py _version.py .eggs # Space-separated list of files or directories to process. inputs = monai # Keep going past errors to analyze as many files as possible. 
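Note on the patch above: the revised `excludes` pattern in `monai/__init__.py` is what keeps the package importable when the C++/CUDA extensions have not been compiled — besides the `handlers` modules (which rely on optional external decorators), it now also skips any `*.so` build artifact and the compiled `_C`/`_C_CUDA` modules, while compilation itself moves out of the CI yaml and into `compile_cpp` in `runtests.sh`. A minimal sketch of how such a pattern filters submodule names, assuming it is applied with `re.search` against each candidate name (an illustrative assumption, not the actual `load_submodules` implementation):

    import re

    # pattern introduced in the patch above: skip "handlers", "*.so" artifacts,
    # and the compiled "_C"/"_C_CUDA" extension modules
    excludes = re.compile("(^(handlers))|((\\.so)$)|(_C)")

    for name in ("handlers", "transforms", "lltm_cuda.so", "_C", "_C_CUDA"):
        print(name, "->", "skipped" if excludes.search(name) else "auto-loaded")

Only "transforms" would be auto-loaded in this sketch; the other names are filtered, which matches the intent stated in the patched comment.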
From 68610653e2477852db4ec297e3e38b62ee56004f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 14:00:56 +0100 Subject: [PATCH 19/22] update installation steps --- docs/source/installation.md | 9 ++++++--- runtests.sh | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index 9f464a6863..1148d7a7c6 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -54,10 +54,13 @@ This command will create a ``MONAI/`` folder in your current directory. You can install it by running: ```bash cd MONAI/ -pip install -e . +python setup.py develop + +# to uninstall the package please run: +python setup.py develop --uninstall ``` or simply adding the root directory of the cloned source code (e.g., ``/workspace/Documents/MONAI``) to your ``$PYTHONPATH`` -and the codebase is ready to use. +and the codebase is ready to use (without the additional features of MONAI C++/CUDA extensions). ## Validating the install @@ -128,7 +131,7 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is ``` [nibabel, skimage, pillow, tensorboard, gdown, ignite] ``` -which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, +which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, `gdown`, and `pytorch-ignite` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/runtests.sh b/runtests.sh index e4926a8f72..298b82312b 100755 --- a/runtests.sh +++ b/runtests.sh @@ -104,9 +104,11 @@ function compile_cpp { function clean_py() { # uninstall the development package + echo "Uninstalling MONAI development files..." ${cmdPrefix}python setup.py -v develop --uninstall # remove temporary files + echo "Removing temporary files..." 
TO_CLEAN=${*:-'.'} find ${TO_CLEAN} -type f -name "*.py[co]" -delete find ${TO_CLEAN} -type f -name ".coverage" -delete @@ -115,6 +117,7 @@ function clean_py() { find ${TO_CLEAN} -depth -type d -name ".eggs" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name "monai.egg-info" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name "build" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -type d -name "dist" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".mypy_cache" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".pytype" -exec rm -r "{}" + find ${TO_CLEAN} -depth -type d -name ".coverage" -exec rm -r "{}" + @@ -227,9 +230,6 @@ then function dryrun { echo " " "$@"; } fi -# unconditionally report on the state of monai -print_version - if [ $doCleanup = true ] then echo "${separator}${blue}clean${noColor}" @@ -243,6 +243,9 @@ fi # try to compile MONAI cpp compile_cpp +# unconditionally report on the state of monai +print_version + if [ $doBlackFormat = true ] then set +e # disable exit on failure so that diagnostics can be given on failure From 07cfdec81685c9772a05f5a3435fb7694b6091ad Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 30 Jul 2020 14:06:15 +0100 Subject: [PATCH 20/22] remove temp full tests --- .github/workflows/setupapp.yml | 1 - monai/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 916c9220dc..42e5dd7b1b 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -5,7 +5,6 @@ on: push: branches: - master - pull_request: jobs: coverage-py3: diff --git a/monai/__init__.py b/monai/__init__.py index f244d07037..3bbb327206 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -23,7 +23,7 @@ __basedir__ = os.path.dirname(__file__) # handlers_* have some external decorators the users may not have installed -# *.so files and folder "_C" may not exists when the cpp extensions are not compiled +# *.so files and folder "_C" may not exist when the cpp extensions are not compiled excludes = "(^(handlers))|((\\.so)$)|(_C)" # load directory modules only, skip loading individual files From c642318cf0ee0e26dd13f61cd154851dac367bf2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Jul 2020 11:58:03 +0100 Subject: [PATCH 21/22] more comments on the setup --- .github/workflows/pythonapp.yml | 3 +++ .github/workflows/setupapp.yml | 2 +- setup.cfg | 2 +- setup.py | 2 ++ 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 14d55fcfa9..e36e76f843 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -114,6 +114,9 @@ jobs: - name: Install dependencies run: | python -m pip install --user --upgrade pip setuptools wheel twine + # install the latest pytorch for testing + # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated + # fresh torch installation according to pyproject.toml python -m pip install torch>=1.4 - name: Test source archive and wheel file run: | diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 42e5dd7b1b..75cf89b5f3 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -95,7 +95,7 @@ jobs: cat "requirements-dev.txt" python -m pip install -r requirements-dev.txt python -m pip list - python setup.py develop + python setup.py develop # to compile extensions using the system native torch - name: Run quick tests (CPU ${{ runner.os }}) run: | 
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' diff --git a/setup.cfg b/setup.cfg index b724cccb38..24a7de1b14 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ ignore = # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' N812 per-file-ignores = __init__.py: F401 -exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv, _version.py +exclude = *.pyi,.git,.eggs,monai/_version.py,versioneer.py,venv,_version.py [isort] known_first_party = monai diff --git a/setup.py b/setup.py index 7c1f2a03ee..aaced090b0 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,8 @@ def get_extensions(): try: import torch from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + + print(f"setup.py with torch {torch.__version__}") except ImportError: warnings.warn("torch cpp/cuda building skipped.") return [] From 4b205b2d4d5b8dcc4414ec55a281034258f0f61f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 31 Jul 2020 12:29:39 +0100 Subject: [PATCH 22/22] test cache action --- .github/workflows/pythonapp.yml | 61 +++++++++++++++++++++++++++++++-- .github/workflows/setupapp.yml | 49 +++++++++++++++++++++++++- runtests.sh | 2 +- 3 files changed, 107 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index e36e76f843..a4295c013c 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -8,6 +8,10 @@ on: pull_request: jobs: + # caching of these jobs: + # - docker-20-03-py3-pip- (shared) + # - ubuntu py37 pip- + # - os-latest-pip- (shared) flake8-py3: runs-on: ubuntu-latest steps: @@ -16,10 +20,20 @@ jobs: uses: actions/setup-python@v1 with: python-version: 3.7 + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | python -m pip install --upgrade pip wheel - pip install -r requirements-dev.txt + python -m pip install -r requirements-dev.txt - name: Lint with black formater run: | $(pwd)/runtests.sh --nounittests --black @@ -51,6 +65,18 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | @@ -78,6 +104,16 @@ jobs: runs-on: [self-hosted, linux, x64] steps: - uses: actions/checkout@v2 + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }} - name: Install the dependencies run: | which python @@ -94,7 +130,6 @@ jobs: python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))" python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))' ./runtests.sh --quick - pip list coverage xml - name: Upload coverage uses: codecov/codecov-action@v1 @@ -111,6 +146,16 @@ jobs: uses: 
actions/setup-python@v1 with: python-version: '3.x' + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | python -m pip install --user --upgrade pip setuptools wheel twine @@ -163,10 +208,20 @@ jobs: uses: actions/setup-python@v1 with: python-version: 3.7 + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install dependencies run: | python -m pip install --upgrade pip wheel - pip install -r docs/requirements.txt + python -m pip install -r docs/requirements.txt - name: Make html run: | cd docs/ diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 75cf89b5f3..3c2c3b4168 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -7,6 +7,10 @@ on: - master jobs: + # caching of these jobs: + # - docker-20-03-py3-pip- (shared) + # - ubuntu py36 37 38-pip- + # - os-latest-pip (shared) coverage-py3: container: image: nvcr.io/nvidia/pytorch:20.03-py3 @@ -14,6 +18,16 @@ jobs: runs-on: [self-hosted, linux, x64] steps: - uses: actions/checkout@v2 + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: docker-20-03-py3-pip-${{ steps.pip-cache.outputs.datew }} - name: Install the dependencies run: | which python @@ -21,9 +35,9 @@ jobs: python -m pip uninstall -y torch torchvision python -m pip install torch==1.4 python -m pip install -r requirements-dev.txt - python -m pip list - name: Run unit tests report coverage run: | + python -m pip list nvidia-smi export CUDA_VISIBLE_DEVICES=$(python -m tests.utils) echo $CUDA_VISIBLE_DEVICES @@ -50,6 +64,16 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install the dependencies run: | python -m pip install --upgrade pip wheel @@ -57,6 +81,7 @@ jobs: python -m pip install -r requirements-dev.txt - name: Run quick tests CPU ubuntu run: | + python -m pip list python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' ./runtests.sh --quick coverage xml @@ -83,6 +108,18 @@ jobs: run: | which python python -m pip install --upgrade pip wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output name=datew::$(date '+%Y-%V')" + echo "::set-output name=dir::$(pip cache dir)" + shell: bash + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} - if: runner.os == 'windows' name: Install torch cpu from pytorch.org (Windows only) run: | @@ -110,6 +147,16 @@ jobs: uses: actions/setup-python@v1 with: python-version: 3.7 + - name: cache weekly timestamp + id: pip-cache + run: | + echo "::set-output 
name=datew::$(date '+%Y-%V')" + - name: cache for pip + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} - name: Install the default branch run: | pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI diff --git a/runtests.sh b/runtests.sh index 298b82312b..4e848d7161 100755 --- a/runtests.sh +++ b/runtests.sh @@ -88,7 +88,7 @@ function print_version { function install_deps { echo "Pip installing MONAI development dependencies and compile MONAI cpp extensions..." - ${cmdPrefix}pip install -r requirements-dev.txt + ${cmdPrefix}python -m pip install -r requirements-dev.txt } function compile_cpp {
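The caching steps added in this last patch all follow the same idiom: a `datew` output derived from `date '+%Y-%V'` is baked into the `actions/cache` key, so each runner reuses its pip cache within a week and rebuilds it automatically when the ISO week number changes, while the docker-based jobs share the fixed `docker-20-03-py3-pip-` prefix instead of `runner.os`. A rough Python equivalent of the key computation, for illustration only (the workflows do this in shell, and `runner.os` is assumed to resolve to `Linux` on the ubuntu runners):

    import datetime

    today = datetime.date.today()
    # shell: date '+%Y-%V'  ->  calendar year plus zero-padded ISO week number
    datew = f"{today.year}-{today.isocalendar()[1]:02d}"
    cache_key = f"Linux-pip-{datew}"  # e.g. "Linux-pip-2020-31"
    print(cache_key)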