From 6c334ffe9a0fb951ea9f63161c70855e3676ef5b Mon Sep 17 00:00:00 2001
From: OeslleLucena
Date: Tue, 24 Jan 2023 17:00:01 +0000
Subject: [PATCH 01/12] README change
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ed2e2fbc..183f1e50 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
# MONAI Generative Models
-Prototyping repo for generative models to be integrated into MONAI core.
+Prototyping repository for generative models to be integrated into MONAI core.
## Features
* Network architectures: Diffusion Model, Autoencoder-KL, VQ-VAE, (Multi-scale) Patch-GAN discriminator.
* Diffusion Model Schedulers: DDPM, DDIM, and PNDM.
From a6aeaf629414086655ed0b35c4e29e2622f51cdc Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Wed, 1 Feb 2023 13:15:17 +0000
Subject: [PATCH 02/12] Fix scale_factor (#214)
Co-authored-by: virginiafdez
---
generative/inferers/inferer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 6927f5ca..f65bdb20 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -347,14 +347,14 @@ def sample(
latent = outputs
with torch.no_grad():
- image = autoencoder_model.decode_stage_2_outputs(latent) * self.scale_factor
+ image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
if save_intermediates:
intermediates = []
for latent_intermediate in latent_intermediates:
with torch.no_grad():
intermediates.append(
- autoencoder_model.decode_stage_2_outputs(latent_intermediate) * self.scale_factor
+ autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
)
return image, intermediates
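This fix restores the inverse of the encoding convention: the inferer multiplies latents by scale_factor on the way into the diffusion model, so decoding must divide by the same factor. A minimal sketch of the round trip (the value 0.18215 is purely illustrative; in practice the factor is estimated from the data):

    import torch

    scale_factor = 0.18215              # illustrative value
    latent = torch.randn(1, 3, 16, 16)  # stand-in for a stage-1 latent

    scaled = latent * scale_factor      # encode path scales the latent
    recovered = scaled / scale_factor   # decode must divide to invert it
    assert torch.allclose(latent, recovered)

Multiplying again on decode, as the previous code did, would hand the decoder a latent scaled by scale_factor squared.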
From 4d5dac16c6079d6a7ebbf1dccd5f7ea3e54ed375 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Wed, 1 Feb 2023 13:28:12 +0000
Subject: [PATCH 03/12] Add is_fake_3d setting condition to error (#215)
* The flag is_fake_3d has to be set to False to use PerceptualLoss with 3D networks; otherwise, an error occurs. Modified the error raised in __init__ to account for this flag setting.
---------
Co-authored-by: virginiafdez
---
generative/losses/perceptual.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 68b94d75..5a3640b8 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -45,8 +45,11 @@ def __init__(
if spatial_dims not in [2, 3]:
raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
- if spatial_dims == 2 and "medicalnet_" in network_type:
- raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.")
+ if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
+ raise ValueError(
+ "MedicalNet networks are only compatible with ``spatial_dims=3``."
+ "Argument is_fake_3d must be set to False."
+ )
self.spatial_dims = spatial_dims
if spatial_dims == 3 and is_fake_3d is False:
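MedicalNet backbones are pretrained 3D networks, while the fake-3D path scores a volume as batches of 2D slices, so the two cannot be combined. A sketch of the valid and now-rejected combinations, assuming one of the medicalnet_* network types (the exact name below is an assumption) with pretrained weights available:

    from generative.losses.perceptual import PerceptualLoss

    # Valid: a true 3D perceptual loss with a MedicalNet backbone.
    loss = PerceptualLoss(
        spatial_dims=3,
        network_type="medicalnet_resnet10_23datasets",
        is_fake_3d=False,
    )

    # With this patch, either of these raises ValueError at construction
    # instead of failing later inside the network:
    # PerceptualLoss(spatial_dims=3, network_type="medicalnet_resnet10_23datasets", is_fake_3d=True)
    # PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")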
From 9ce4c1ad8d56b2018f40b139819fecc292aa1797 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 14:11:44 +0000
Subject: [PATCH 04/12] Add check length of num_channels and attention_levels
(#221)
Signed-off-by: Walter Hugo Lopez Pinaya
---
generative/networks/nets/diffusion_model_unet.py | 3 +++
tests/test_diffusion_model_unet.py | 14 ++++++++++++++
2 files changed, 17 insertions(+)
diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 8d13b76f..76caa08b 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -1636,6 +1636,9 @@ def __init__(
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
+ if len(num_channels) != len(attention_levels):
+ raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")
+
if isinstance(num_head_channels, int):
num_head_channels = (num_head_channels,) * len(attention_levels)
diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py
index 7e116ea6..f34b4a3f 100644
--- a/tests/test_diffusion_model_unet.py
+++ b/tests/test_diffusion_model_unet.py
@@ -360,6 +360,20 @@ def test_conditioned_models_no_class_labels(self):
)
net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long())
+ def test_model_num_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
def test_script_unconditioned_2d_models(self):
net = DiffusionModelUNet(
spatial_dims=2,
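With this check, num_channels and attention_levels are validated as paired per-level sequences, so a mismatch fails fast at construction time. A brief sketch of the contract (the widths are illustrative):

    from generative.networks.nets import DiffusionModelUNet

    # Valid: three resolution levels, so both sequences have length 3.
    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=1,
        num_channels=(8, 8, 8),
        attention_levels=(False, False, True),
        norm_num_groups=8,
        num_head_channels=8,
    )

    # Lengths 3 and 2, as in the new test above, now raise ValueError
    # immediately rather than erroring deep inside block construction.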
From e1fea91bea25efe20436ab2aa393b6f6982e72ab Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 16:39:15 +0000
Subject: [PATCH 05/12] Suspend CI (#224)
Signed-off-by: Walter Hugo Lopez Pinaya
---
.github/workflows/pythonapp-min.yml | 342 ++++++++++++++--------------
1 file changed, 171 insertions(+), 171 deletions(-)
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
index 0d629c29..557b5776 100644
--- a/.github/workflows/pythonapp-min.yml
+++ b/.github/workflows/pythonapp-min.yml
@@ -1,171 +1,171 @@
-# Jenkinsfile.monai-premerge
-name: premerge-min
-
-on:
- # quick tests for pull requests and the releasing branches
- push:
- branches:
- - main
- pull_request:
-
-concurrency:
- # automatically cancel the previously triggered workflows when there's a newer version
- group: build-min-${{ github.event.pull_request.number || github.ref }}
- cancel-in-progress: true
-
-jobs:
- # caching of these jobs:
- # - docker-py3-pip- (shared)
- # - ubuntu py37 pip-
- # - os-latest-pip- (shared)
- min-dep-os: # min dependencies installed tests for different OS
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [windows-latest, macOS-latest, ubuntu-latest]
- timeout-minutes: 40
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python 3.8
- uses: actions/setup-python@v4
- with:
- python-version: '3.8'
- - name: Prepare pip wheel
- run: |
- which python
- python -m pip install --upgrade pip wheel
- - name: cache weekly timestamp
- id: pip-cache
- run: |
- echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- shell: bash
- - name: cache for pip
- uses: actions/cache@v3
- id: cache
- with:
- path: ${{ steps.pip-cache.outputs.dir }}
- key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
- - if: runner.os == 'windows'
- name: Install torch cpu from pytorch.org (Windows only)
- run: |
- python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- - name: Install the dependencies
- run: |
- # min. requirements
- python -m pip install torch==1.13.1
- python -m pip install -r requirements-min.txt
- python -m pip list
- BUILD_MONAI=0 python setup.py develop # no compile of extensions
- shell: bash
- - name: Run quick tests (CPU ${{ runner.os }})
- run: |
- python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
- python -c "import monai; monai.config.print_config()"
- ./runtests.sh --min
- shell: bash
- env:
- QUICKTEST: True
-
- min-dep-py3: # min dependencies installed tests for different python
- runs-on: ubuntu-latest
- strategy:
- fail-fast: false
- matrix:
- python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
- timeout-minutes: 40
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python-version }}
- - name: Prepare pip wheel
- run: |
- which python
- python -m pip install --user --upgrade pip setuptools wheel
- - name: cache weekly timestamp
- id: pip-cache
- run: |
- echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- shell: bash
- - name: cache for pip
- uses: actions/cache@v3
- id: cache
- with:
- path: ${{ steps.pip-cache.outputs.dir }}
- key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
- - name: Install the dependencies
- run: |
- # min. requirements
- python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- python -m pip install -r requirements-min.txt
- python -m pip list
- BUILD_MONAI=0 python setup.py develop # no compile of extensions
- shell: bash
- - name: Run quick tests (CPU ${{ runner.os }})
- run: |
- python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
- python -c "import monai; monai.config.print_config()"
- ./runtests.sh --min
- env:
- QUICKTEST: True
-
- min-dep-pytorch: # min dependencies installed tests for different pytorch
- runs-on: ubuntu-latest
- strategy:
- fail-fast: false
- matrix:
- pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest']
- timeout-minutes: 40
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python 3.8
- uses: actions/setup-python@v4
- with:
- python-version: '3.8'
- - name: Prepare pip wheel
- run: |
- which python
- python -m pip install --user --upgrade pip setuptools wheel
- - name: cache weekly timestamp
- id: pip-cache
- run: |
- echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- shell: bash
- - name: cache for pip
- uses: actions/cache@v3
- id: cache
- with:
- path: ${{ steps.pip-cache.outputs.dir }}
- key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
- - name: Install the dependencies
- run: |
- # min. requirements
- if [ ${{ matrix.pytorch-version }} == "latest" ]; then
- python -m pip install torch
- elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then
- python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
- elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
- python -m pip install torch==1.9.1
- elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
- python -m pip install torch==1.10.2
- elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then
- python -m pip install torch==1.11.0
- elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then
- python -m pip install torch==1.12.1
- fi
- python -m pip install -r requirements-min.txt
- python -m pip list
- BUILD_MONAI=0 python setup.py develop # no compile of extensions
- shell: bash
- - name: Run quick tests (pytorch ${{ matrix.pytorch-version }})
- run: |
- python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
- python -c "import monai; monai.config.print_config()"
- ./runtests.sh --min
- env:
- QUICKTEST: True
+## Jenkinsfile.monai-premerge
+#name: premerge-min
+#
+#on:
+# # quick tests for pull requests and the releasing branches
+# push:
+# branches:
+# - main
+# pull_request:
+#
+#concurrency:
+# # automatically cancel the previously triggered workflows when there's a newer version
+# group: build-min-${{ github.event.pull_request.number || github.ref }}
+# cancel-in-progress: true
+#
+#jobs:
+# # caching of these jobs:
+# # - docker-py3-pip- (shared)
+# # - ubuntu py37 pip-
+# # - os-latest-pip- (shared)
+# min-dep-os: # min dependencies installed tests for different OS
+# runs-on: ${{ matrix.os }}
+# strategy:
+# fail-fast: false
+# matrix:
+# os: [windows-latest, macOS-latest, ubuntu-latest]
+# timeout-minutes: 40
+# steps:
+# - uses: actions/checkout@v3
+# - name: Set up Python 3.8
+# uses: actions/setup-python@v4
+# with:
+# python-version: '3.8'
+# - name: Prepare pip wheel
+# run: |
+# which python
+# python -m pip install --upgrade pip wheel
+# - name: cache weekly timestamp
+# id: pip-cache
+# run: |
+# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+# shell: bash
+# - name: cache for pip
+# uses: actions/cache@v3
+# id: cache
+# with:
+# path: ${{ steps.pip-cache.outputs.dir }}
+# key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
+# - if: runner.os == 'windows'
+# name: Install torch cpu from pytorch.org (Windows only)
+# run: |
+# python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+# - name: Install the dependencies
+# run: |
+# # min. requirements
+# python -m pip install torch==1.13.1
+# python -m pip install -r requirements-min.txt
+# python -m pip list
+# BUILD_MONAI=0 python setup.py develop # no compile of extensions
+# shell: bash
+# - name: Run quick tests (CPU ${{ runner.os }})
+# run: |
+# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+# python -c "import monai; monai.config.print_config()"
+# ./runtests.sh --min
+# shell: bash
+# env:
+# QUICKTEST: True
+#
+# min-dep-py3: # min dependencies installed tests for different python
+# runs-on: ubuntu-latest
+# strategy:
+# fail-fast: false
+# matrix:
+# python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
+# timeout-minutes: 40
+# steps:
+# - uses: actions/checkout@v3
+# - name: Set up Python ${{ matrix.python-version }}
+# uses: actions/setup-python@v4
+# with:
+# python-version: ${{ matrix.python-version }}
+# - name: Prepare pip wheel
+# run: |
+# which python
+# python -m pip install --user --upgrade pip setuptools wheel
+# - name: cache weekly timestamp
+# id: pip-cache
+# run: |
+# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+# shell: bash
+# - name: cache for pip
+# uses: actions/cache@v3
+# id: cache
+# with:
+# path: ${{ steps.pip-cache.outputs.dir }}
+# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
+# - name: Install the dependencies
+# run: |
+# # min. requirements
+# python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
+# python -m pip install -r requirements-min.txt
+# python -m pip list
+# BUILD_MONAI=0 python setup.py develop # no compile of extensions
+# shell: bash
+# - name: Run quick tests (CPU ${{ runner.os }})
+# run: |
+# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+# python -c "import monai; monai.config.print_config()"
+# ./runtests.sh --min
+# env:
+# QUICKTEST: True
+#
+# min-dep-pytorch: # min dependencies installed tests for different pytorch
+# runs-on: ubuntu-latest
+# strategy:
+# fail-fast: false
+# matrix:
+# pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest']
+# timeout-minutes: 40
+# steps:
+# - uses: actions/checkout@v3
+# - name: Set up Python 3.8
+# uses: actions/setup-python@v4
+# with:
+# python-version: '3.8'
+# - name: Prepare pip wheel
+# run: |
+# which python
+# python -m pip install --user --upgrade pip setuptools wheel
+# - name: cache weekly timestamp
+# id: pip-cache
+# run: |
+# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+# shell: bash
+# - name: cache for pip
+# uses: actions/cache@v3
+# id: cache
+# with:
+# path: ${{ steps.pip-cache.outputs.dir }}
+# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
+# - name: Install the dependencies
+# run: |
+# # min. requirements
+# if [ ${{ matrix.pytorch-version }} == "latest" ]; then
+# python -m pip install torch
+# elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then
+# python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
+# elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
+# python -m pip install torch==1.9.1
+# elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
+# python -m pip install torch==1.10.2
+# elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then
+# python -m pip install torch==1.11.0
+# elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then
+# python -m pip install torch==1.12.1
+# fi
+# python -m pip install -r requirements-min.txt
+# python -m pip list
+# BUILD_MONAI=0 python setup.py develop # no compile of extensions
+# shell: bash
+# - name: Run quick tests (pytorch ${{ matrix.pytorch-version }})
+# run: |
+# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+# python -c "import monai; monai.config.print_config()"
+# ./runtests.sh --min
+# env:
+# QUICKTEST: True
From fb72040822d5e1b80afd313265baf8d241f4a17a Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Sat, 4 Feb 2023 16:44:10 +0000
Subject: [PATCH 06/12] Remove CI
Signed-off-by: Walter Hugo Lopez Pinaya
---
.github/workflows/pythonapp-min.yml | 171 ----------------------------
1 file changed, 171 deletions(-)
delete mode 100644 .github/workflows/pythonapp-min.yml
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
deleted file mode 100644
index 557b5776..00000000
--- a/.github/workflows/pythonapp-min.yml
+++ /dev/null
@@ -1,171 +0,0 @@
-## Jenkinsfile.monai-premerge
-#name: premerge-min
-#
-#on:
-# # quick tests for pull requests and the releasing branches
-# push:
-# branches:
-# - main
-# pull_request:
-#
-#concurrency:
-# # automatically cancel the previously triggered workflows when there's a newer version
-# group: build-min-${{ github.event.pull_request.number || github.ref }}
-# cancel-in-progress: true
-#
-#jobs:
-# # caching of these jobs:
-# # - docker-py3-pip- (shared)
-# # - ubuntu py37 pip-
-# # - os-latest-pip- (shared)
-# min-dep-os: # min dependencies installed tests for different OS
-# runs-on: ${{ matrix.os }}
-# strategy:
-# fail-fast: false
-# matrix:
-# os: [windows-latest, macOS-latest, ubuntu-latest]
-# timeout-minutes: 40
-# steps:
-# - uses: actions/checkout@v3
-# - name: Set up Python 3.8
-# uses: actions/setup-python@v4
-# with:
-# python-version: '3.8'
-# - name: Prepare pip wheel
-# run: |
-# which python
-# python -m pip install --upgrade pip wheel
-# - name: cache weekly timestamp
-# id: pip-cache
-# run: |
-# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
-# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-# shell: bash
-# - name: cache for pip
-# uses: actions/cache@v3
-# id: cache
-# with:
-# path: ${{ steps.pip-cache.outputs.dir }}
-# key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
-# - if: runner.os == 'windows'
-# name: Install torch cpu from pytorch.org (Windows only)
-# run: |
-# python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
-# - name: Install the dependencies
-# run: |
-# # min. requirements
-# python -m pip install torch==1.13.1
-# python -m pip install -r requirements-min.txt
-# python -m pip list
-# BUILD_MONAI=0 python setup.py develop # no compile of extensions
-# shell: bash
-# - name: Run quick tests (CPU ${{ runner.os }})
-# run: |
-# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
-# python -c "import monai; monai.config.print_config()"
-# ./runtests.sh --min
-# shell: bash
-# env:
-# QUICKTEST: True
-#
-# min-dep-py3: # min dependencies installed tests for different python
-# runs-on: ubuntu-latest
-# strategy:
-# fail-fast: false
-# matrix:
-# python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-# timeout-minutes: 40
-# steps:
-# - uses: actions/checkout@v3
-# - name: Set up Python ${{ matrix.python-version }}
-# uses: actions/setup-python@v4
-# with:
-# python-version: ${{ matrix.python-version }}
-# - name: Prepare pip wheel
-# run: |
-# which python
-# python -m pip install --user --upgrade pip setuptools wheel
-# - name: cache weekly timestamp
-# id: pip-cache
-# run: |
-# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
-# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-# shell: bash
-# - name: cache for pip
-# uses: actions/cache@v3
-# id: cache
-# with:
-# path: ${{ steps.pip-cache.outputs.dir }}
-# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
-# - name: Install the dependencies
-# run: |
-# # min. requirements
-# python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
-# python -m pip install -r requirements-min.txt
-# python -m pip list
-# BUILD_MONAI=0 python setup.py develop # no compile of extensions
-# shell: bash
-# - name: Run quick tests (CPU ${{ runner.os }})
-# run: |
-# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
-# python -c "import monai; monai.config.print_config()"
-# ./runtests.sh --min
-# env:
-# QUICKTEST: True
-#
-# min-dep-pytorch: # min dependencies installed tests for different pytorch
-# runs-on: ubuntu-latest
-# strategy:
-# fail-fast: false
-# matrix:
-# pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest']
-# timeout-minutes: 40
-# steps:
-# - uses: actions/checkout@v3
-# - name: Set up Python 3.8
-# uses: actions/setup-python@v4
-# with:
-# python-version: '3.8'
-# - name: Prepare pip wheel
-# run: |
-# which python
-# python -m pip install --user --upgrade pip setuptools wheel
-# - name: cache weekly timestamp
-# id: pip-cache
-# run: |
-# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
-# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-# shell: bash
-# - name: cache for pip
-# uses: actions/cache@v3
-# id: cache
-# with:
-# path: ${{ steps.pip-cache.outputs.dir }}
-# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}
-# - name: Install the dependencies
-# run: |
-# # min. requirements
-# if [ ${{ matrix.pytorch-version }} == "latest" ]; then
-# python -m pip install torch
-# elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then
-# python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
-# elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
-# python -m pip install torch==1.9.1
-# elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
-# python -m pip install torch==1.10.2
-# elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then
-# python -m pip install torch==1.11.0
-# elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then
-# python -m pip install torch==1.12.1
-# fi
-# python -m pip install -r requirements-min.txt
-# python -m pip list
-# BUILD_MONAI=0 python setup.py develop # no compile of extensions
-# shell: bash
-# - name: Run quick tests (pytorch ${{ matrix.pytorch-version }})
-# run: |
-# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
-# python -c "import monai; monai.config.print_config()"
-# ./runtests.sh --min
-# env:
-# QUICKTEST: True
From 8796bab3e6490f52d917389d7b16b4b6f9bb777e Mon Sep 17 00:00:00 2001
From: Petru-Daniel Tudosiu
Date: Mon, 6 Feb 2023 11:35:27 +0000
Subject: [PATCH 07/12] Add Sequence Ordering class (#168)
* Added Ordering class and required enums as well as the tests.
* Fixes format
* Review fixes.
---
generative/utils/enums.py | 16 +-
generative/utils/ordering.py | 205 +++++++++++++++++++++++
tests/test_ordering.py | 316 +++++++++++++++++++++++++++++++++++
3 files changed, 535 insertions(+), 2 deletions(-)
create mode 100644 generative/utils/ordering.py
create mode 100644 tests/test_ordering.py
diff --git a/generative/utils/enums.py b/generative/utils/enums.py
index 327e20ea..9c510f89 100644
--- a/generative/utils/enums.py
+++ b/generative/utils/enums.py
@@ -12,7 +12,7 @@
from typing import TYPE_CHECKING
from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import StrEnum, min_version, optional_import
if TYPE_CHECKING:
from ignite.engine import EventEnum
@@ -22,7 +22,7 @@
)
-class AdversarialKeys:
+class AdversarialKeys(StrEnum):
REALS = "reals"
REAL_LOGITS = "real_logits"
FAKES = "fakes"
@@ -44,3 +44,15 @@ class AdversarialIterationEvents(EventEnum):
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"
+
+
+class OrderingType(StrEnum):
+ RASTER_SCAN = "raster_scan"
+ S_CURVE = "s_curve"
+ RANDOM = "random"
+
+
+class OrderingTransformations(StrEnum):
+ ROTATE_90 = "rotate_90"
+ TRANSPOSE = "transpose"
+ REFLECT = "reflect"
diff --git a/generative/utils/ordering.py b/generative/utils/ordering.py
new file mode 100644
index 00000000..bb9a6db8
--- /dev/null
+++ b/generative/utils/ordering.py
@@ -0,0 +1,205 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from generative.utils.enums import OrderingTransformations, OrderingType
+
+
+class Ordering:
+ """
+ Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with
+ one of the following transformations:
+ - Reflection - see np.flip for more details.
+ - Transposition - see np.transpose for more details.
+ - 90-degree rotation - see np.rot90 for more details.
+
+ The transformations are applied in the order specified by the transformation_order parameter.
+
+ Args:
+ ordering_type: The ordering type. One of the following:
+ - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from
+ top to bottom. Also called row-major ordering.
+ - 's_curve': The image is projected into a 1D sequence by scanning its rows in alternating directions (left to
+ right, then right to left), tracing a snake-like, boustrophedon path.
+ - 'random': The image is projected into a 1D sequence by randomly shuffling the image.
+ spatial_dims: The number of spatial dimensions of the image.
+ dimensions: The dimensions of the image.
+ reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension.
+ transpositions_axes: A tuple of tuples indicating the axes to transpose the image along.
+ rot90_axes: A tuple of tuples indicating the axes to rotate the image along.
+ transformation_order: The order in which to apply the transformations.
+ """
+
+ def __init__(
+ self,
+ ordering_type: str,
+ spatial_dims: int,
+ dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
+ reflected_spatial_dims: Union[Tuple[bool, bool], Tuple[bool, bool, bool]] = (),
+ transpositions_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (),
+ rot90_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (),
+ transformation_order: Tuple[str, ...] = (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ ) -> None:
+ super().__init__()
+ self.ordering_type = ordering_type
+
+ if self.ordering_type not in list(OrderingType):
+ raise ValueError(
+ f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}."
+ )
+
+ self.spatial_dims = spatial_dims
+ self.dimensions = dimensions
+
+ if len(dimensions) != self.spatial_dims + 1:
+ raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.")
+
+ self.reflected_spatial_dims = reflected_spatial_dims
+ self.transpositions_axes = transpositions_axes
+ self.rot90_axes = rot90_axes
+ if len(set(transformation_order)) != len(transformation_order):
+ raise ValueError(f"No duplicates are allowed. Received {transformation_order}.")
+
+ for transformation in transformation_order:
+ if transformation not in list(OrderingTransformations):
+ raise ValueError(
+ f"Valid transformations are {list(OrderingTransformations)} but received {transformation}."
+ )
+ self.transformation_order = transformation_order
+
+ self.template = self._create_template()
+ self._sequence_ordering = self._create_ordering()
+ self._revert_sequence_ordering = np.argsort(self._sequence_ordering)
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ x = x[self._sequence_ordering]
+
+ return x
+
+ def get_sequence_ordering(self) -> np.ndarray:
+ return self._sequence_ordering
+
+ def get_revert_sequence_ordering(self) -> np.ndarray:
+ return self._revert_sequence_ordering
+
+ def _create_ordering(self) -> np.ndarray:
+ self.template = self._transform_template()
+ order = self._order_template(template=self.template)
+
+ return order
+
+ def _create_template(self) -> np.ndarray:
+ spatial_dimensions = self.dimensions[1:]
+ template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions)
+
+ return template
+
+ def _transform_template(self) -> np.ndarray:
+ for transformation in self.transformation_order:
+ if transformation == OrderingTransformations.TRANSPOSE.value:
+ self.template = self._transpose_template(template=self.template)
+ elif transformation == OrderingTransformations.ROTATE_90.value:
+ self.template = self._rot90_template(template=self.template)
+ elif transformation == OrderingTransformations.REFLECT.value:
+ self.template = self._flip_template(template=self.template)
+
+ return self.template
+
+ def _transpose_template(self, template: np.ndarray) -> np.ndarray:
+ for axes in self.transpositions_axes:
+ template = np.transpose(template, axes=axes)
+
+ return template
+
+ def _flip_template(self, template: np.ndarray) -> np.ndarray:
+ for axis, to_reflect in enumerate(self.reflected_spatial_dims):
+ template = np.flip(template, axis=axis) if to_reflect else template
+
+ return template
+
+ def _rot90_template(self, template: np.ndarray) -> np.ndarray:
+ for axes in self.rot90_axes:
+ template = np.rot90(template, axes=axes)
+
+ return template
+
+ def _order_template(self, template: np.ndarray) -> np.ndarray:
+ depths = None
+ if self.spatial_dims == 2:
+ rows, columns = template.shape[0], template.shape[1]
+ else:
+ rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2])
+
+ sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)
+
+ ordering = np.array([template[tuple(e)] for e in sequence])
+
+ return ordering
+
+ @staticmethod
+ def raster_scan_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
+ idx = []
+
+ for r in range(rows):
+ for c in range(cols):
+ if depths:
+ for d in range(depths):
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx = np.array(idx)
+
+ return idx
+
+ @staticmethod
+ def s_curve_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
+ idx = []
+
+ for r in range(rows):
+ col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)
+ for c in col_idx:
+ if depths:
+ depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1)
+
+ for d in depth_idx:
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx = np.array(idx)
+
+ return idx
+
+ @staticmethod
+ def random_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
+ idx = []
+
+ for r in range(rows):
+ for c in range(cols):
+ if depths:
+ for d in range(depths):
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx = np.array(idx)
+ np.random.shuffle(idx)
+
+ return idx
diff --git a/tests/test_ordering.py b/tests/test_ordering.py
new file mode 100644
index 00000000..c40b77e9
--- /dev/null
+++ b/tests/test_ordering.py
@@ -0,0 +1,316 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from generative.utils.enums import OrderingTransformations, OrderingType
+from generative.utils.ordering import Ordering
+
+TEST_2D_NON_RANDOM = [
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 3, 2],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [2, 3, 0, 1],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [2, 3, 1, 0],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 2, 1, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 2, 3, 1],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [1, 3, 0, 2],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [1, 3, 2, 0],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 3, 2],
+ ],
+]
+
+TEST_2D_RANDOM = [
+ [
+ {
+ "ordering_type": OrderingType.RANDOM,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [[0, 1, 2, 3], [0, 1, 3, 2]],
+ ]
+]
+
+TEST_3D = [
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 3,
+ "dimensions": (1, 2, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3, 4, 5, 6, 7],
+ ]
+]
+
+TEST_ORDERING_TYPE_FAILURE = [
+ [
+ {
+ "ordering_type": "hilbert",
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ }
+ ],
+]
+
+TEST_ORDERING_TRANSFORMATION_FAILURE = [
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ "flip",
+ ),
+ }
+ ],
+]
+
+TEST_REVERT = [
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ }
+ ]
+]
+
+
+class TestOrdering(unittest.TestCase):
+ @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D)
+ def test_ordering(self, input_param, expected_sequence_ordering):
+ ordering = Ordering(**input_param)
+ self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True))
+
+ @parameterized.expand(TEST_ORDERING_TYPE_FAILURE)
+ def test_ordering_type_failure(self, input_param):
+ with self.assertRaises(ValueError):
+ Ordering(**input_param)
+
+ @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE)
+ def test_ordering_transformation_failure(self, input_param):
+ with self.assertRaises(ValueError):
+ Ordering(**input_param)
+
+ @parameterized.expand(TEST_2D_RANDOM)
+ def test_random(self, input_param, not_in_expected_sequence_ordering):
+ ordering = Ordering(**input_param)
+
+ not_in = [
+ np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True)
+ for sequence in not_in_expected_sequence_ordering
+ ]
+
+ self.assertFalse(np.any(not_in))
+
+ @parameterized.expand(TEST_REVERT)
+ def test_revert(self, input_param):
+ sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten()
+
+ ordering = Ordering(**input_param)
+
+ reverted_sequence = sequence[ordering.get_sequence_ordering()]
+ reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()]
+
+ self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True))
+
+
+if __name__ == "__main__":
+ unittest.main()
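Taken together, an Ordering maps a flattened image to a 1D token sequence and back. A minimal sketch of the round trip the tests above exercise, using an s-curve ordering on a 2x2 image (the first entry of dimensions is the non-spatial channel dimension):

    import torch

    from generative.utils.enums import OrderingType
    from generative.utils.ordering import Ordering

    ordering = Ordering(
        ordering_type=OrderingType.S_CURVE.value,
        spatial_dims=2,
        dimensions=(1, 2, 2),
    )

    tokens = torch.arange(4)  # flattened 2x2 grid: [0, 1, 2, 3]
    reordered = tokens[ordering.get_sequence_ordering()]           # s-curve order: [0, 1, 3, 2]
    restored = reordered[ordering.get_revert_sequence_ordering()]
    assert torch.equal(restored, tokens)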
From 7992e5ab74334174f80551096ce68a10e7f26733 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 6 Feb 2023 14:12:14 +0000
Subject: [PATCH 08/12] Add DecoderOnlyTransformer (#225)
* Add AutoregressiveTransformer
* Add cross-attention and rename model
Signed-off-by: Walter Hugo Lopez Pinaya
---
generative/networks/nets/__init__.py | 1 +
generative/networks/nets/transformer.py | 61 +++++++++++++++++++++++++
requirements-dev.txt | 1 +
tests/min_tests.py | 1 +
tests/test_transformer.py | 42 +++++++++++++++++
5 files changed, 106 insertions(+)
create mode 100644 generative/networks/nets/transformer.py
create mode 100644 tests/test_transformer.py
diff --git a/generative/networks/nets/__init__.py b/generative/networks/nets/__init__.py
index 8f7b51d8..ed15c2f8 100644
--- a/generative/networks/nets/__init__.py
+++ b/generative/networks/nets/__init__.py
@@ -12,4 +12,5 @@
from .autoencoderkl import AutoencoderKL
from .diffusion_model_unet import DiffusionModelUNet
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
+from .transformer import DecoderOnlyTransformer
from .vqvae import VQVAE
diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py
new file mode 100644
index 00000000..84476aef
--- /dev/null
+++ b/generative/networks/nets/transformer.py
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from x_transformers import Decoder, TransformerWrapper
+
+__all__ = ["DecoderOnlyTransformer"]
+
+
+class DecoderOnlyTransformer(nn.Module):
+ """Decoder-only (Autoregressive) Transformer model.
+
+ Args:
+ num_tokens: Number of tokens in the vocabulary.
+ max_seq_len: Maximum sequence length.
+ attn_layers_dim: Dimensionality of the attention layers.
+ attn_layers_depth: Number of attention layers.
+ attn_layers_heads: Number of attention heads.
+ with_cross_attention: Whether to use cross attention for conditioning.
+ """
+
+ def __init__(
+ self,
+ num_tokens: int,
+ max_seq_len: int,
+ attn_layers_dim: int,
+ attn_layers_depth: int,
+ attn_layers_heads: int,
+ with_cross_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_tokens = num_tokens
+ self.max_seq_len = max_seq_len
+ self.attn_layers_dim = attn_layers_dim
+ self.attn_layers_depth = attn_layers_depth
+ self.attn_layers_heads = attn_layers_heads
+
+ self.model = TransformerWrapper(
+ num_tokens=self.num_tokens,
+ max_seq_len=self.max_seq_len,
+ attn_layers=Decoder(
+ dim=self.attn_layers_dim,
+ depth=self.attn_layers_depth,
+ heads=self.attn_layers_heads,
+ cross_attend=with_cross_attention,
+ ),
+ )
+
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
+ return self.model(x, context=context)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1a9ddd18..2c8e5a6d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -54,3 +54,4 @@ nni
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
lpips==0.1.4
+x-transformers==1.8.1
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 82fb1130..b4373dd8 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -33,6 +33,7 @@ def run_testsuit():
"test_integration_workflows_adversarial",
"test_latent_diffusion_inferer",
"test_diffusion_inferer",
+ "test_transformer",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
new file mode 100644
index 00000000..9ddb4ca7
--- /dev/null
+++ b/tests/test_transformer.py
@@ -0,0 +1,42 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from monai.networks import eval_mode
+
+from generative.networks.nets import DecoderOnlyTransformer
+
+
+class TestDecoderOnlyTransformer(unittest.TestCase):
+ def test_unconditioned_models(self):
+ net = DecoderOnlyTransformer(
+ num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
+ )
+ with eval_mode(net):
+ net.forward(torch.randint(0, 10, (1, 16)))
+
+ def test_conditioned_models(self):
+ net = DecoderOnlyTransformer(
+ num_tokens=10,
+ max_seq_len=16,
+ attn_layers_dim=8,
+ attn_layers_depth=2,
+ attn_layers_heads=2,
+ with_cross_attention=True,
+ )
+ with eval_mode(net):
+ net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8))
+
+
+if __name__ == "__main__":
+ unittest.main()
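The wrapper is intentionally thin and leaves sampling to the caller. As a rough sketch (not a utility shipped by this patch), greedy autoregressive generation with the model could look like:

    import torch

    from generative.networks.nets import DecoderOnlyTransformer

    net = DecoderOnlyTransformer(
        num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
    )

    seq = torch.zeros(1, 1, dtype=torch.long)  # arbitrary seed token
    with torch.no_grad():
        for _ in range(net.max_seq_len - 1):
            logits = net(seq)                               # (batch, seq_len, num_tokens)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            seq = torch.cat([seq, next_token], dim=1)       # append most likely token

When the model is built with with_cross_attention=True, the same loop applies with a context tensor passed on each call.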
From cecdce30db3f43ef21d47451bfb4f888abf657cb Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Mon, 6 Feb 2023 15:07:32 +0000
Subject: [PATCH 09/12] Remove ch_mult from AutoencoderKL (#220)
* Change num_channels to Sequence
* Update tutorials
---
generative/networks/nets/autoencoderkl.py | 174 ++++++++++--------
tests/test_autoencoderkl.py | 31 ++--
tests/test_latent_diffusion_inferer.py | 3 +-
.../2d_autoencoderkl_tutorial.ipynb | 3 +-
.../2d_autoencoderkl_tutorial.py | 3 +-
.../3d_autoencoderkl_tutorial.ipynb | 3 +-
.../3d_autoencoderkl_tutorial.py | 3 +-
7 files changed, 115 insertions(+), 105 deletions(-)
diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py
index d081cad5..d13b26ad 100644
--- a/generative/networks/nets/autoencoderkl.py
+++ b/generative/networks/nets/autoencoderkl.py
@@ -296,16 +296,12 @@ class Encoder(nn.Module):
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels.
- num_channels: number of filters in the first downsampling.
+ num_channels: sequence of block output channels.
out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
- ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if
- you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2],
- the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2,
- and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels.
num_res_blocks: number of residual blocks (see ResBlock) per level.
norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
norm_eps: epsilon for the normalization.
- attention_levels: indicate which level from ch_mult contain an attention block.
+ attention_levels: indicate which levels of num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
"""
@@ -313,20 +309,15 @@ def __init__(
self,
spatial_dims: int,
in_channels: int,
- num_channels: int,
+ num_channels: Sequence[int],
out_channels: int,
- ch_mult: Sequence[int],
num_res_blocks: int,
norm_num_groups: int,
norm_eps: float,
- attention_levels: Optional[Sequence[bool]] = None,
+ attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
) -> None:
super().__init__()
-
- if attention_levels is None:
- attention_levels = (False,) * len(ch_mult)
-
self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.num_channels = num_channels
@@ -336,15 +327,13 @@ def __init__(
self.norm_eps = norm_eps
self.attention_levels = attention_levels
- in_ch_mult = (1,) + tuple(ch_mult)
-
blocks = []
# Initial convolution
blocks.append(
Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
- out_channels=num_channels,
+ out_channels=num_channels[0],
strides=1,
kernel_size=3,
padding=1,
@@ -353,52 +342,73 @@ def __init__(
)
# Residual and downsampling blocks
- for i in range(len(ch_mult)):
- block_in_ch = num_channels * in_ch_mult[i]
- block_out_ch = num_channels * ch_mult[i]
+ output_channel = num_channels[0]
+ for i in range(len(num_channels)):
+ input_channel = output_channel
+ output_channel = num_channels[i]
+ is_final_block = i == len(num_channels) - 1
+
for _ in range(self.num_res_blocks):
blocks.append(
ResBlock(
spatial_dims=spatial_dims,
- in_channels=block_in_ch,
+ in_channels=input_channel,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
- out_channels=block_out_ch,
+ out_channels=output_channel,
)
)
- block_in_ch = block_out_ch
+ input_channel = output_channel
if attention_levels[i]:
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
- num_channels=block_in_ch,
+ num_channels=input_channel,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)
- if i != len(ch_mult) - 1:
- blocks.append(Downsample(spatial_dims, block_in_ch))
+ if not is_final_block:
+ blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel))
# Non-local attention block
if with_nonlocal_attn is True:
- blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
blocks.append(
- AttentionBlock(
+ ResBlock(
spatial_dims=spatial_dims,
- num_channels=block_in_ch,
+ in_channels=num_channels[-1],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
+ out_channels=num_channels[-1],
)
)
- blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+ blocks.append(
+ AttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=num_channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+ blocks.append(
+ ResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=num_channels[-1],
+ )
+ )
# Normalise and convert to latent size
- blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+ blocks.append(
+ nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True)
+ )
blocks.append(
Convolution(
spatial_dims=self.spatial_dims,
- in_channels=block_in_ch,
+ in_channels=num_channels[-1],
out_channels=out_channels,
strides=1,
kernel_size=3,
@@ -421,48 +431,39 @@ class Decoder(nn.Module):
Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
- num_channels: number of filters in the last upsampling.
+ num_channels: sequence of block output channels.
in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
out_channels: number of output channels.
- ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the
- last layer, there will be a transition from num_channels to out_channels. In the layers before that,
- channels will be the product of the previous number of channels by ch_mult.
num_res_blocks: number of residual blocks (see ResBlock) per level.
norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
norm_eps: epsilon for the normalization.
- attention_levels: indicate which level from ch_mult contain an attention block.
+ attention_levels: indicate which levels of num_channels contain an attention block.
with_nonlocal_attn: if True use non-local attention block.
"""
def __init__(
self,
spatial_dims: int,
- num_channels: int,
+ num_channels: Sequence[int],
in_channels: int,
out_channels: int,
- ch_mult: Sequence[int],
num_res_blocks: int,
norm_num_groups: int,
norm_eps: float,
- attention_levels: Optional[Sequence[bool]] = None,
+ attention_levels: Sequence[bool],
with_nonlocal_attn: bool = True,
) -> None:
super().__init__()
-
- if attention_levels is None:
- attention_levels = (False,) * len(ch_mult)
-
self.spatial_dims = spatial_dims
self.num_channels = num_channels
self.in_channels = in_channels
self.out_channels = out_channels
- self.ch_mult = ch_mult
self.num_res_blocks = num_res_blocks
self.norm_num_groups = norm_num_groups
self.norm_eps = norm_eps
self.attention_levels = attention_levels
- block_in_ch = num_channels * self.ch_mult[-1]
+ reversed_block_out_channels = list(reversed(num_channels))
blocks = []
# Initial convolution
@@ -470,7 +471,7 @@ def __init__(
Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
- out_channels=block_in_ch,
+ out_channels=reversed_block_out_channels[0],
strides=1,
kernel_size=3,
padding=1,
@@ -480,25 +481,53 @@ def __init__(
# Non-local attention block
if with_nonlocal_attn is True:
- blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+ blocks.append(
+ ResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
- num_channels=block_in_ch,
+ num_channels=reversed_block_out_channels[0],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)
- blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+ blocks.append(
+ ResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
- for i in reversed(range(len(ch_mult))):
- block_out_ch = num_channels * self.ch_mult[i]
+ reversed_attention_levels = list(reversed(attention_levels))
+ block_out_ch = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ block_in_ch = block_out_ch
+ block_out_ch = reversed_block_out_channels[i]
+ is_final_block = i == len(num_channels) - 1
for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch))
+ blocks.append(
+ ResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=block_out_ch,
+ )
+ )
block_in_ch = block_out_ch
- if attention_levels[i]:
+ if reversed_attention_levels[i]:
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
@@ -508,8 +537,8 @@ def __init__(
)
)
- if i != 0:
- blocks.append(Upsample(spatial_dims, block_in_ch))
+ if not is_final_block:
+ blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))
blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
blocks.append(
@@ -542,14 +571,12 @@ class AutoencoderKL(nn.Module):
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels.
out_channels: number of output channels.
- num_channels: number of filters in the first downsampling / last upsampling.
- latent_channels: latent embedding dimension.
- ch_mult: multiplier of the number of channels in each downsampling layer (+ initial one). i.e.: If you want 3
- downsamplings, it should be a 4-element list.
num_res_blocks: number of residual blocks (see ResBlock) per level.
+ num_channels: sequence of block output channels.
+ attention_levels: sequence of levels to add attention.
+ latent_channels: latent embedding dimension.
norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
norm_eps: epsilon for the normalization.
- attention_levels: indicate which level from ch_mult contain an attention block.
with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
"""
@@ -557,35 +584,31 @@ class AutoencoderKL(nn.Module):
def __init__(
self,
spatial_dims: int,
- in_channels: int,
- out_channels: int,
- num_channels: int,
- latent_channels: int,
- ch_mult: Sequence[int],
- num_res_blocks: int,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ num_res_blocks: int = 2,
+ num_channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ latent_channels: int = 3,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
- attention_levels: Optional[Sequence[bool]] = None,
with_encoder_nonlocal_attn: bool = True,
with_decoder_nonlocal_attn: bool = True,
) -> None:
super().__init__()
- if attention_levels is None:
- attention_levels = (False,) * len(ch_mult)
- # The number of channels should be multiple of num_groups
- if (num_channels % norm_num_groups) != 0:
- raise ValueError("AutoencoderKL expects number of channels being multiple of number of groups")
+ # All channel counts must be multiples of norm_num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+ raise ValueError("AutoencoderKL expects all num_channels to be multiples of norm_num_groups")
- if len(ch_mult) != len(attention_levels):
- raise ValueError("AutoencoderKL expects ch_mult being same size of attention_levels")
+ if len(num_channels) != len(attention_levels):
+ raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
self.encoder = Encoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channels=num_channels,
out_channels=latent_channels,
- ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
@@ -597,7 +620,6 @@ def __init__(
num_channels=num_channels,
in_channels=latent_channels,
out_channels=out_channels,
- ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
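For context on the interface change above: the old single-integer num_channels plus ch_mult multipliers are replaced by an explicit per-level channel sequence, paired one-to-one with attention_levels. A minimal constructor sketch under the new interface (hyperparameter values are illustrative; the import path is assumed, not taken from this patch):

from generative.networks.nets import AutoencoderKL  # import path assumed

# Old style (removed): num_channels=128, ch_mult=(1, 2, 3) -> widths (128, 256, 384)
# New style: pass the per-level widths directly; each must be divisible by
# norm_num_groups, and len(num_channels) must equal len(attention_levels).
model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 384),
    attention_levels=(False, False, True),
    latent_channels=8,
    num_res_blocks=1,
    norm_num_groups=32,
)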
diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py
index 11b1ec68..bb6af8f8 100644
--- a/tests/test_autoencoderkl.py
+++ b/tests/test_autoencoderkl.py
@@ -26,10 +26,9 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
- "attention_levels": None,
+ "attention_levels": (False, False, False),
"num_res_blocks": 1,
"norm_num_groups": 4,
},
@@ -42,9 +41,8 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
"attention_levels": (False, False, False),
"num_res_blocks": 1,
"norm_num_groups": 4,
@@ -58,9 +56,8 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
"attention_levels": (False, False, True),
"num_res_blocks": 1,
"norm_num_groups": 4,
@@ -74,9 +71,8 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
"attention_levels": (False, False, False),
"num_res_blocks": 1,
"norm_num_groups": 4,
@@ -91,9 +87,8 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
"attention_levels": (False, False, False),
"num_res_blocks": 1,
"norm_num_groups": 4,
@@ -109,9 +104,8 @@
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 4,
+ "num_channels": (4, 4, 4),
"latent_channels": 4,
- "ch_mult": [1, 1, 1],
"attention_levels": (False, False, True),
"num_res_blocks": 1,
"norm_num_groups": 4,
@@ -145,25 +139,24 @@ def test_model_channels_not_multiple_of_norm_num_group(self):
spatial_dims=2,
in_channels=1,
out_channels=1,
- num_channels=24,
+ num_channels=(24, 24, 24),
+ attention_levels=(False, False, False),
latent_channels=8,
- ch_mult=[1, 1, 1],
num_res_blocks=1,
norm_num_groups=16,
)
- def test_model_ch_mult_not_same_size_of_attention_levels(self):
+ def test_model_num_channels_not_same_size_of_attention_levels(self):
with self.assertRaises(ValueError):
AutoencoderKL(
spatial_dims=2,
in_channels=1,
out_channels=1,
- num_channels=24,
+ num_channels=(24, 24, 24),
+ attention_levels=(False, False),
latent_channels=8,
- ch_mult=[1, 1, 1],
num_res_blocks=1,
norm_num_groups=16,
- attention_levels=(True,),
)
def test_shape_reconstruction(self):
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
index e258f2e9..58394754 100644
--- a/tests/test_latent_diffusion_inferer.py
+++ b/tests/test_latent_diffusion_inferer.py
@@ -25,9 +25,8 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
- "num_channels": 8,
+ "num_channels": (8, 8, 8),
"latent_channels": 3,
- "ch_mult": [1, 1, 1],
"attention_levels": [False, False, False],
"num_res_blocks": 1,
"with_encoder_nonlocal_attn": False,
diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb
index 5a0d0ed0..3a8c8019 100644
--- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb
+++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb
@@ -598,9 +598,8 @@
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
- " num_channels=128,\n",
+ " num_channels=(128, 256, 384),\n",
" latent_channels=8,\n",
- " ch_mult=(1, 2, 3),\n",
" num_res_blocks=1,\n",
" norm_num_groups=32,\n",
" attention_levels=(False, False, True),\n",
diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py
index 9081de10..aab95a25 100644
--- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py
+++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py
@@ -121,9 +121,8 @@
spatial_dims=2,
in_channels=1,
out_channels=1,
- num_channels=128,
+ num_channels=(128, 256, 384),
latent_channels=8,
- ch_mult=(1, 2, 3),
num_res_blocks=1,
norm_num_groups=32,
attention_levels=(False, False, True),
diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb
index 8e164b82..fd2eb249 100644
--- a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb
+++ b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb
@@ -501,9 +501,8 @@
" spatial_dims=3,\n",
" in_channels=1,\n",
" out_channels=1,\n",
- " num_channels=32,\n",
+ " num_channels=(32, 64, 64),\n",
" latent_channels=3,\n",
- " ch_mult=(1, 2, 2),\n",
" num_res_blocks=1,\n",
" norm_num_groups=32,\n",
" attention_levels=(False, False, True),\n",
diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
index 97979189..cfce9e80 100644
--- a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
+++ b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
@@ -175,9 +175,8 @@
spatial_dims=3,
in_channels=1,
out_channels=1,
- num_channels=32,
+ num_channels=(32, 64, 64),
latent_channels=3,
- ch_mult=(1, 2, 2),
num_res_blocks=1,
norm_num_groups=32,
attention_levels=(False, False, True),
From 69118fac311f9ad9a10f4d0d78863892f2cf41a6 Mon Sep 17 00:00:00 2001
From: Jessica Dafflon
Date: Tue, 7 Feb 2023 12:23:01 -0500
Subject: [PATCH 10/12] Fix print messages for MS-SSIM (#230)
* Fix print messages for MS-SSIM
* Fix f-string formatting (#230)
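The root cause: a trailing backslash inside an f-string keeps the next source line's leading indentation inside the rendered message, while adjacent string literals concatenate cleanly. A minimal standalone illustration (values made up, not taken from the metric):

# Backslash continuation leaks the indentation into the message:
bad = f"Image size needs to be divisible by 4 but \
        dimension 2 has size 30"
# Adjacent literals concatenate without the stray whitespace:
good = f"Image size needs to be divisible by 4 but " f"dimension 2 has size 30"
print(repr(bad))   # '...but         dimension 2 has size 30'  <- stray spaces
print(repr(good))  # '...but dimension 2 has size 30'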
---
generative/metrics/ms_ssim.py | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py
index 2a5047c7..769a7ff3 100644
--- a/generative/metrics/ms_ssim.py
+++ b/generative/metrics/ms_ssim.py
@@ -80,10 +80,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
if not x.shape == y.shape:
- raise ValueError(
- f"Input images should have the same dimensions, \
- but got {x.shape} and {y.shape}."
- )
+ raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.")
for d in range(len(x.shape) - 1, 1, -1):
x = x.squeeze(dim=d)
@@ -94,10 +91,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
elif len(x.shape) == 5:
avg_pool = F.avg_pool3d
else:
- raise ValueError(
- f"Input images should be 4-d or 5-d tensors, but \
- got {x.shape}"
- )
+ raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}")
if self.weights is None:
# as per Ref 1 - Sec 3.2.
@@ -109,14 +103,14 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
for idx, shape_size in enumerate(x.shape[2:]):
if shape_size % divisible_by != 0:
raise ValueError(
- f"Image size needs to be divisible by {divisible_by} but \
- dimension {idx + 2} has size {shape_size}"
+ f"Image size needs to be divisible by {divisible_by} but "
+ f"dimension {idx + 2} has size {shape_size}"
)
if shape_size < bigger_than:
raise ValueError(
- f"Image size should be larger than {bigger_than} due to \
- the {len(self.weights) - 1} downsamplings in MS-SSIM."
+ f"Image size should be larger than {bigger_than} due to "
+ f"the {len(self.weights) - 1} downsamplings in MS-SSIM."
)
levels = self.weights.shape[0]
From eeeca8f3e76e8c2460606e50895e883aaf06d3dd Mon Sep 17 00:00:00 2001
From: Mark Graham
Date: Tue, 7 Feb 2023 13:09:20 -0600
Subject: [PATCH 11/12] Update pretrained diffusion model (#233)
* Use find_unused_parameters=True, now necessary for training the DDPM with DDP
* Download updated checkpoint
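Both changes in one hedged sketch (assumes the process group is already initialised and a CUDA device is set up per rank; the repo:tag form of torch.hub.load pins the download to the v0.2 release):

import torch
from torch.nn.parallel import DistributedDataParallel

device = torch.device("cuda:0")  # assumed set up per rank
model = torch.hub.load(
    "marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True
).to(device)
# find_unused_parameters=True lets DDP tolerate parameters that receive no
# gradient in a given forward pass, at a small synchronisation cost.
model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)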
---
.../2d_ddpm/2d_ddpm_compare_schedulers.ipynb | 4 ++--
.../2d_ddpm/2d_ddpm_compare_schedulers.py | 4 ++--
.../generative/2d_ddpm/2d_ddpm_inpainting.ipynb | 4 ++--
.../generative/2d_ddpm/2d_ddpm_inpainting.py | 4 ++--
.../generative/2d_ddpm/2d_ddpm_tutorial.ipynb | 15 +++++++++++++--
tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py | 4 ++--
.../distributed_training/ddpm_training_ddp.py | 2 +-
7 files changed, 24 insertions(+), 13 deletions(-)
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb
index 62de8796..1cadad2c 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb
@@ -818,7 +818,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
- " model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
+ " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 100\n",
" val_interval = 10\n",
@@ -1096,7 +1096,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.6"
+ "version": "3.8.13"
}
},
"nbformat": 4,
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py
index a6dfe00b..dee1bed2 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py
@@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
-# jupytext_version: 1.14.4
+# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
@@ -207,7 +207,7 @@
use_pretrained = False
if use_pretrained:
- model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
+ model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 100
val_interval = 10
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb
index 21d5e83f..1fc3f9f0 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb
@@ -636,7 +636,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
- " model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
+ " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 50\n",
" val_interval = 5\n",
@@ -914,7 +914,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.6"
+ "version": "3.8.13"
}
},
"nbformat": 4,
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py
index ea6a1f8f..58874039 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py
@@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
-# jupytext_version: 1.14.4
+# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
@@ -191,7 +191,7 @@
use_pretrained = False
if use_pretrained:
- model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
+ model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 50
val_interval = 5
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
index 673170a8..40595441 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
@@ -42,6 +42,7 @@
"execution_count": 2,
"id": "dd62a552",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -137,6 +138,7 @@
"execution_count": 3,
"id": "8fc58c80",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -169,6 +171,7 @@
"execution_count": 4,
"id": "ad5a1948",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -194,6 +197,7 @@
"execution_count": 5,
"id": "65e1c200",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -232,6 +236,7 @@
"execution_count": 6,
"id": "e2f9bebd",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -271,6 +276,7 @@
"execution_count": 7,
"id": "938318c2",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -320,6 +326,7 @@
"execution_count": 8,
"id": "b698f4f8",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -372,6 +379,7 @@
"execution_count": 9,
"id": "2c52e4f4",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -415,6 +423,7 @@
"execution_count": 10,
"id": "0f697a13",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -763,7 +772,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
- " model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
+ " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 75\n",
" val_interval = 5\n",
@@ -852,6 +861,7 @@
"execution_count": 11,
"id": "2cdcda81",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -901,6 +911,7 @@
"execution_count": 12,
"id": "1427e5d4",
"metadata": {
+ "collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -984,7 +995,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.6"
+ "version": "3.8.13"
}
},
"nbformat": 4,
diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
index 0384d33e..2d81ddb6 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
@@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
-# jupytext_version: 1.14.4
+# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
@@ -190,7 +190,7 @@
use_pretrained = False
if use_pretrained:
- model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
+ model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 75
val_interval = 5
diff --git a/tutorials/generative/distributed_training/ddpm_training_ddp.py b/tutorials/generative/distributed_training/ddpm_training_ddp.py
index f111d52a..07fab1b0 100644
--- a/tutorials/generative/distributed_training/ddpm_training_ddp.py
+++ b/tutorials/generative/distributed_training/ddpm_training_ddp.py
@@ -197,7 +197,7 @@ def main_worker(args):
inferer = DiffusionInferer(scheduler)
# wrap the model with DistributedDataParallel module
- model = DistributedDataParallel(model, device_ids=[device])
+ model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)
# start a typical PyTorch training
best_metric = 10000