From 6c334ffe9a0fb951ea9f63161c70855e3676ef5b Mon Sep 17 00:00:00 2001 From: OeslleLucena Date: Tue, 24 Jan 2023 17:00:01 +0000 Subject: [PATCH 01/12] README change --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ed2e2fbc..183f1e50 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

 # MONAI Generative Models
-Prototyping repo for generative models to be integrated into MONAI core.
+Prototyping repository for generative models to be integrated into MONAI core.

 ## Features
 * Network architectures: Diffusion Model, Autoencoder-KL, VQ-VAE, (Multi-scale) Patch-GAN discriminator.
 * Diffusion Model Schedulers: DDPM, DDIM, and PNDM.

From a6aeaf629414086655ed0b35c4e29e2622f51cdc Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Wed, 1 Feb 2023 13:15:17 +0000
Subject: [PATCH 02/12] Fix scale_factor (#214)

Co-authored-by: virginiafdez
---
 generative/inferers/inferer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 6927f5ca..f65bdb20 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -347,14 +347,14 @@ def sample(
             latent = outputs

         with torch.no_grad():
-            image = autoencoder_model.decode_stage_2_outputs(latent) * self.scale_factor
+            image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)

         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
                 with torch.no_grad():
                     intermediates.append(
-                        autoencoder_model.decode_stage_2_outputs(latent_intermediate) * self.scale_factor
+                        autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
                     )
             return image, intermediates

From 4d5dac16c6079d6a7ebbf1dccd5f7ea3e54ed375 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Wed, 1 Feb 2023 13:28:12 +0000
Subject: [PATCH 03/12] Add is_fake_3d setting condition to error (#215)

* The flag is_fake_3d must be set to False in order to use PerceptualLoss
  with 3D MedicalNet networks; otherwise an error is raised. Modified the
  error in __init__ to account for this flag setting.

---------

Co-authored-by: virginiafdez
---
 generative/losses/perceptual.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py
index 68b94d75..5a3640b8 100644
--- a/generative/losses/perceptual.py
+++ b/generative/losses/perceptual.py
@@ -45,8 +45,11 @@ def __init__(
         if spatial_dims not in [2, 3]:
             raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")

-        if spatial_dims == 2 and "medicalnet_" in network_type:
-            raise ValueError("MedicalNet networks are only compatible with ``spatial_dims=3``.")
+        if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
+            raise ValueError(
+                "MedicalNet networks are only compatible with ``spatial_dims=3``. "
+                "Argument is_fake_3d must be set to False."
+ ) self.spatial_dims = spatial_dims if spatial_dims == 3 and is_fake_3d is False: From 9ce4c1ad8d56b2018f40b139819fecc292aa1797 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 4 Feb 2023 14:11:44 +0000 Subject: [PATCH 04/12] Add check length of num_channels and attention_levels (#221) Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 3 +++ tests/test_diffusion_model_unet.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 8d13b76f..76caa08b 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1636,6 +1636,9 @@ def __init__( if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + if len(num_channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + if isinstance(num_head_channels, int): num_head_channels = (num_head_channels,) * len(attention_levels) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 7e116ea6..f34b4a3f 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -360,6 +360,20 @@ def test_conditioned_models_no_class_labels(self): ) net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + def test_script_unconditioned_2d_models(self): net = DiffusionModelUNet( spatial_dims=2, From e1fea91bea25efe20436ab2aa393b6f6982e72ab Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 4 Feb 2023 16:39:15 +0000 Subject: [PATCH 05/12] Suspend CI (#224) Signed-off-by: Walter Hugo Lopez Pinaya --- .github/workflows/pythonapp-min.yml | 342 ++++++++++++++-------------- 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 0d629c29..557b5776 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -1,171 +1,171 @@ -# Jenkinsfile.monai-premerge -name: premerge-min - -on: - # quick tests for pull requests and the releasing branches - push: - branches: - - main - pull_request: - -concurrency: - # automatically cancel the previously triggered workflows when there's a newer version - group: build-min-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - # caching of these jobs: - # - docker-py3-pip- (shared) - # - ubuntu py37 pip- - # - os-latest-pip- (shared) - min-dep-os: # min dependencies installed tests for different OS - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [windows-latest, macOS-latest, ubuntu-latest] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 - with: - python-version: '3.8' - - name: Prepare pip wheel - run: | - which python - python -m pip install --upgrade pip wheel - - name: cache weekly timestamp - id: pip-cache - run: 
| - echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - shell: bash - - name: cache for pip - uses: actions/cache@v3 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} - - if: runner.os == 'windows' - name: Install torch cpu from pytorch.org (Windows only) - run: | - python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch==1.13.1 - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - shell: bash - env: - QUICKTEST: True - - min-dep-py3: # min dependencies installed tests for different python - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Prepare pip wheel - run: | - which python - python -m pip install --user --upgrade pip setuptools wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - shell: bash - - name: cache for pip - uses: actions/cache@v3 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} - - name: Install the dependencies - run: | - # min. requirements - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (CPU ${{ runner.os }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True - - min-dep-pytorch: # min dependencies installed tests for different pytorch - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest'] - timeout-minutes: 40 - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 - with: - python-version: '3.8' - - name: Prepare pip wheel - run: | - which python - python -m pip install --user --upgrade pip setuptools wheel - - name: cache weekly timestamp - id: pip-cache - run: | - echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - shell: bash - - name: cache for pip - uses: actions/cache@v3 - id: cache - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} - - name: Install the dependencies - run: | - # min. 
requirements - if [ ${{ matrix.pytorch-version }} == "latest" ]; then - python -m pip install torch - elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then - python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu - elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then - python -m pip install torch==1.9.1 - elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then - python -m pip install torch==1.10.2 - elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then - python -m pip install torch==1.11.0 - elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then - python -m pip install torch==1.12.1 - fi - python -m pip install -r requirements-min.txt - python -m pip list - BUILD_MONAI=0 python setup.py develop # no compile of extensions - shell: bash - - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) - run: | - python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' - python -c "import monai; monai.config.print_config()" - ./runtests.sh --min - env: - QUICKTEST: True +## Jenkinsfile.monai-premerge +#name: premerge-min +# +#on: +# # quick tests for pull requests and the releasing branches +# push: +# branches: +# - main +# pull_request: +# +#concurrency: +# # automatically cancel the previously triggered workflows when there's a newer version +# group: build-min-${{ github.event.pull_request.number || github.ref }} +# cancel-in-progress: true +# +#jobs: +# # caching of these jobs: +# # - docker-py3-pip- (shared) +# # - ubuntu py37 pip- +# # - os-latest-pip- (shared) +# min-dep-os: # min dependencies installed tests for different OS +# runs-on: ${{ matrix.os }} +# strategy: +# fail-fast: false +# matrix: +# os: [windows-latest, macOS-latest, ubuntu-latest] +# timeout-minutes: 40 +# steps: +# - uses: actions/checkout@v3 +# - name: Set up Python 3.8 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.8' +# - name: Prepare pip wheel +# run: | +# which python +# python -m pip install --upgrade pip wheel +# - name: cache weekly timestamp +# id: pip-cache +# run: | +# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT +# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT +# shell: bash +# - name: cache for pip +# uses: actions/cache@v3 +# id: cache +# with: +# path: ${{ steps.pip-cache.outputs.dir }} +# key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} +# - if: runner.os == 'windows' +# name: Install torch cpu from pytorch.org (Windows only) +# run: | +# python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html +# - name: Install the dependencies +# run: | +# # min. 
requirements +# python -m pip install torch==1.13.1 +# python -m pip install -r requirements-min.txt +# python -m pip list +# BUILD_MONAI=0 python setup.py develop # no compile of extensions +# shell: bash +# - name: Run quick tests (CPU ${{ runner.os }}) +# run: | +# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' +# python -c "import monai; monai.config.print_config()" +# ./runtests.sh --min +# shell: bash +# env: +# QUICKTEST: True +# +# min-dep-py3: # min dependencies installed tests for different python +# runs-on: ubuntu-latest +# strategy: +# fail-fast: false +# matrix: +# python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] +# timeout-minutes: 40 +# steps: +# - uses: actions/checkout@v3 +# - name: Set up Python ${{ matrix.python-version }} +# uses: actions/setup-python@v4 +# with: +# python-version: ${{ matrix.python-version }} +# - name: Prepare pip wheel +# run: | +# which python +# python -m pip install --user --upgrade pip setuptools wheel +# - name: cache weekly timestamp +# id: pip-cache +# run: | +# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT +# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT +# shell: bash +# - name: cache for pip +# uses: actions/cache@v3 +# id: cache +# with: +# path: ${{ steps.pip-cache.outputs.dir }} +# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} +# - name: Install the dependencies +# run: | +# # min. requirements +# python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu +# python -m pip install -r requirements-min.txt +# python -m pip list +# BUILD_MONAI=0 python setup.py develop # no compile of extensions +# shell: bash +# - name: Run quick tests (CPU ${{ runner.os }}) +# run: | +# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' +# python -c "import monai; monai.config.print_config()" +# ./runtests.sh --min +# env: +# QUICKTEST: True +# +# min-dep-pytorch: # min dependencies installed tests for different pytorch +# runs-on: ubuntu-latest +# strategy: +# fail-fast: false +# matrix: +# pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest'] +# timeout-minutes: 40 +# steps: +# - uses: actions/checkout@v3 +# - name: Set up Python 3.8 +# uses: actions/setup-python@v4 +# with: +# python-version: '3.8' +# - name: Prepare pip wheel +# run: | +# which python +# python -m pip install --user --upgrade pip setuptools wheel +# - name: cache weekly timestamp +# id: pip-cache +# run: | +# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT +# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT +# shell: bash +# - name: cache for pip +# uses: actions/cache@v3 +# id: cache +# with: +# path: ${{ steps.pip-cache.outputs.dir }} +# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} +# - name: Install the dependencies +# run: | +# # min. 
requirements +# if [ ${{ matrix.pytorch-version }} == "latest" ]; then +# python -m pip install torch +# elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then +# python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu +# elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then +# python -m pip install torch==1.9.1 +# elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then +# python -m pip install torch==1.10.2 +# elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then +# python -m pip install torch==1.11.0 +# elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then +# python -m pip install torch==1.12.1 +# fi +# python -m pip install -r requirements-min.txt +# python -m pip list +# BUILD_MONAI=0 python setup.py develop # no compile of extensions +# shell: bash +# - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) +# run: | +# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' +# python -c "import monai; monai.config.print_config()" +# ./runtests.sh --min +# env: +# QUICKTEST: True From fb72040822d5e1b80afd313265baf8d241f4a17a Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 4 Feb 2023 16:44:10 +0000 Subject: [PATCH 06/12] Remove CI Signed-off-by: Walter Hugo Lopez Pinaya --- .github/workflows/pythonapp-min.yml | 171 ---------------------------- 1 file changed, 171 deletions(-) delete mode 100644 .github/workflows/pythonapp-min.yml diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml deleted file mode 100644 index 557b5776..00000000 --- a/.github/workflows/pythonapp-min.yml +++ /dev/null @@ -1,171 +0,0 @@ -## Jenkinsfile.monai-premerge -#name: premerge-min -# -#on: -# # quick tests for pull requests and the releasing branches -# push: -# branches: -# - main -# pull_request: -# -#concurrency: -# # automatically cancel the previously triggered workflows when there's a newer version -# group: build-min-${{ github.event.pull_request.number || github.ref }} -# cancel-in-progress: true -# -#jobs: -# # caching of these jobs: -# # - docker-py3-pip- (shared) -# # - ubuntu py37 pip- -# # - os-latest-pip- (shared) -# min-dep-os: # min dependencies installed tests for different OS -# runs-on: ${{ matrix.os }} -# strategy: -# fail-fast: false -# matrix: -# os: [windows-latest, macOS-latest, ubuntu-latest] -# timeout-minutes: 40 -# steps: -# - uses: actions/checkout@v3 -# - name: Set up Python 3.8 -# uses: actions/setup-python@v4 -# with: -# python-version: '3.8' -# - name: Prepare pip wheel -# run: | -# which python -# python -m pip install --upgrade pip wheel -# - name: cache weekly timestamp -# id: pip-cache -# run: | -# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT -# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT -# shell: bash -# - name: cache for pip -# uses: actions/cache@v3 -# id: cache -# with: -# path: ${{ steps.pip-cache.outputs.dir }} -# key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} -# - if: runner.os == 'windows' -# name: Install torch cpu from pytorch.org (Windows only) -# run: | -# python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html -# - name: Install the dependencies -# run: | -# # min. 
requirements -# python -m pip install torch==1.13.1 -# python -m pip install -r requirements-min.txt -# python -m pip list -# BUILD_MONAI=0 python setup.py develop # no compile of extensions -# shell: bash -# - name: Run quick tests (CPU ${{ runner.os }}) -# run: | -# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' -# python -c "import monai; monai.config.print_config()" -# ./runtests.sh --min -# shell: bash -# env: -# QUICKTEST: True -# -# min-dep-py3: # min dependencies installed tests for different python -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] -# timeout-minutes: 40 -# steps: -# - uses: actions/checkout@v3 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v4 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Prepare pip wheel -# run: | -# which python -# python -m pip install --user --upgrade pip setuptools wheel -# - name: cache weekly timestamp -# id: pip-cache -# run: | -# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT -# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT -# shell: bash -# - name: cache for pip -# uses: actions/cache@v3 -# id: cache -# with: -# path: ${{ steps.pip-cache.outputs.dir }} -# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} -# - name: Install the dependencies -# run: | -# # min. requirements -# python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu -# python -m pip install -r requirements-min.txt -# python -m pip list -# BUILD_MONAI=0 python setup.py develop # no compile of extensions -# shell: bash -# - name: Run quick tests (CPU ${{ runner.os }}) -# run: | -# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' -# python -c "import monai; monai.config.print_config()" -# ./runtests.sh --min -# env: -# QUICKTEST: True -# -# min-dep-pytorch: # min dependencies installed tests for different pytorch -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest'] -# timeout-minutes: 40 -# steps: -# - uses: actions/checkout@v3 -# - name: Set up Python 3.8 -# uses: actions/setup-python@v4 -# with: -# python-version: '3.8' -# - name: Prepare pip wheel -# run: | -# which python -# python -m pip install --user --upgrade pip setuptools wheel -# - name: cache weekly timestamp -# id: pip-cache -# run: | -# echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT -# echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT -# shell: bash -# - name: cache for pip -# uses: actions/cache@v3 -# id: cache -# with: -# path: ${{ steps.pip-cache.outputs.dir }} -# key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} -# - name: Install the dependencies -# run: | -# # min. 
requirements -# if [ ${{ matrix.pytorch-version }} == "latest" ]; then -# python -m pip install torch -# elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then -# python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu -# elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then -# python -m pip install torch==1.9.1 -# elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then -# python -m pip install torch==1.10.2 -# elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then -# python -m pip install torch==1.11.0 -# elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then -# python -m pip install torch==1.12.1 -# fi -# python -m pip install -r requirements-min.txt -# python -m pip list -# BUILD_MONAI=0 python setup.py develop # no compile of extensions -# shell: bash -# - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) -# run: | -# python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' -# python -c "import monai; monai.config.print_config()" -# ./runtests.sh --min -# env: -# QUICKTEST: True From 8796bab3e6490f52d917389d7b16b4b6f9bb777e Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 6 Feb 2023 11:35:27 +0000 Subject: [PATCH 07/12] Add Sequence Ordering class (#168) * Added Ordering class and required enums as well as the tests. * Fixes format * Review fixes. --- generative/utils/enums.py | 16 +- generative/utils/ordering.py | 205 +++++++++++++++++++++++ tests/test_ordering.py | 316 +++++++++++++++++++++++++++++++++++ 3 files changed, 535 insertions(+), 2 deletions(-) create mode 100644 generative/utils/ordering.py create mode 100644 tests/test_ordering.py diff --git a/generative/utils/enums.py b/generative/utils/enums.py index 327e20ea..9c510f89 100644 --- a/generative/utils/enums.py +++ b/generative/utils/enums.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import StrEnum, min_version, optional_import if TYPE_CHECKING: from ignite.engine import EventEnum @@ -22,7 +22,7 @@ ) -class AdversarialKeys: +class AdversarialKeys(StrEnum): REALS = "reals" REAL_LOGITS = "real_logits" FAKES = "fakes" @@ -44,3 +44,15 @@ class AdversarialIterationEvents(EventEnum): DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" + + +class OrderingType(StrEnum): + RASTER_SCAN = "raster_scan" + S_CURVE = "s_curve" + RANDOM = "random" + + +class OrderingTransformations(StrEnum): + ROTATE_90 = "rotate_90" + TRANSPOSE = "transpose" + REFLECT = "reflect" diff --git a/generative/utils/ordering.py b/generative/utils/ordering.py new file mode 100644 index 00000000..bb9a6db8 --- /dev/null +++ b/generative/utils/ordering.py @@ -0,0 +1,205 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+
+from generative.utils.enums import OrderingTransformations, OrderingType
+
+
+class Ordering:
+    """
+    Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with
+    one of the following transformations:
+        - Reflection - see np.flip for more details.
+        - Transposition - see np.transpose for more details.
+        - 90-degree rotation - see np.rot90 for more details.
+
+    The transformations are applied in the order specified by the transformation_order parameter.
+
+    Args:
+        ordering_type: The ordering type. One of the following:
+            - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and
+                from top to bottom. Also called a row-major ordering.
+            - 's_curve': The image is projected into a 1D sequence by scanning the image in a serpentine
+                (snake-like) pattern, traversing the first row from left to right and reversing direction on each
+                subsequent row.
+            - 'random': The image is projected into a 1D sequence by randomly shuffling the image.
+        spatial_dims: The number of spatial dimensions of the image.
+        dimensions: The dimensions of the image.
+        reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension.
+        transpositions_axes: A tuple of tuples indicating the axes to transpose the image along.
+        rot90_axes: A tuple of tuples indicating the axes to rotate the image along.
+        transformation_order: The order in which to apply the transformations.
+    """
+
+    def __init__(
+        self,
+        ordering_type: str,
+        spatial_dims: int,
+        dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
+        reflected_spatial_dims: Union[Tuple[bool, bool], Tuple[bool, bool, bool]] = (),
+        transpositions_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (),
+        rot90_axes: Union[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]] = (),
+        transformation_order: Tuple[str, ...] = (
+            OrderingTransformations.TRANSPOSE.value,
+            OrderingTransformations.ROTATE_90.value,
+            OrderingTransformations.REFLECT.value,
+        ),
+    ) -> None:
+        super().__init__()
+        self.ordering_type = ordering_type
+
+        if self.ordering_type not in list(OrderingType):
+            raise ValueError(
+                f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}."
+            )
+
+        self.spatial_dims = spatial_dims
+        self.dimensions = dimensions
+
+        if len(dimensions) != self.spatial_dims + 1:
+            raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.")
+
+        self.reflected_spatial_dims = reflected_spatial_dims
+        self.transpositions_axes = transpositions_axes
+        self.rot90_axes = rot90_axes
+        if len(set(transformation_order)) != len(transformation_order):
+            raise ValueError(f"No duplicates are allowed. Received {transformation_order}.")
+
+        for transformation in transformation_order:
+            if transformation not in list(OrderingTransformations):
+                raise ValueError(
+                    f"Valid transformations are {list(OrderingTransformations)} but received {transformation}."
+ ) + self.transformation_order = transformation_order + + self.template = self._create_template() + self._sequence_ordering = self._create_ordering() + self._revert_sequence_ordering = np.argsort(self._sequence_ordering) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + x = x[self._sequence_ordering] + + return x + + def get_sequence_ordering(self) -> np.ndarray: + return self._sequence_ordering + + def get_revert_sequence_ordering(self) -> np.ndarray: + return self._revert_sequence_ordering + + def _create_ordering(self) -> np.ndarray: + self.template = self._transform_template() + order = self._order_template(template=self.template) + + return order + + def _create_template(self) -> np.ndarray: + spatial_dimensions = self.dimensions[1:] + template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) + + return template + + def _transform_template(self) -> np.ndarray: + for transformation in self.transformation_order: + if transformation == OrderingTransformations.TRANSPOSE.value: + self.template = self._transpose_template(template=self.template) + elif transformation == OrderingTransformations.ROTATE_90.value: + self.template = self._rot90_template(template=self.template) + elif transformation == OrderingTransformations.REFLECT.value: + self.template = self._flip_template(template=self.template) + + return self.template + + def _transpose_template(self, template: np.ndarray) -> np.ndarray: + for axes in self.transpositions_axes: + template = np.transpose(template, axes=axes) + + return template + + def _flip_template(self, template: np.ndarray) -> np.ndarray: + for axis, to_reflect in enumerate(self.reflected_spatial_dims): + template = np.flip(template, axis=axis) if to_reflect else template + + return template + + def _rot90_template(self, template: np.ndarray) -> np.ndarray: + for axes in self.rot90_axes: + template = np.rot90(template, axes=axes) + + return template + + def _order_template(self, template: np.ndarray) -> np.ndarray: + depths = None + if self.spatial_dims == 2: + rows, columns = template.shape[0], template.shape[1] + else: + rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) + + sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) + + ordering = np.array([template[tuple(e)] for e in sequence]) + + return ordering + + @staticmethod + def raster_scan_idx(rows: int, cols: int, depths: int = None) -> np.ndarray: + idx = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx = np.array(idx) + + return idx + + @staticmethod + def s_curve_idx(rows: int, cols: int, depths: int = None) -> np.ndarray: + idx = [] + + for r in range(rows): + col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) + for c in col_idx: + if depths: + depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) + + for d in depth_idx: + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx = np.array(idx) + + return idx + + @staticmethod + def random_idx(rows: int, cols: int, depths: int = None) -> np.ndarray: + idx = [] + + for r in range(rows): + for c in range(cols): + if depths: + for d in range(depths): + idx.append((r, c, d)) + else: + idx.append((r, c)) + + idx = np.array(idx) + np.random.shuffle(idx) + + return idx diff --git a/tests/test_ordering.py b/tests/test_ordering.py new file mode 100644 index 00000000..c40b77e9 --- /dev/null +++ b/tests/test_ordering.py @@ -0,0 +1,316 @@ +# Copyright (c) 
MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from generative.utils.enums import OrderingTransformations, OrderingType +from generative.utils.ordering import Ordering + +TEST_2D_NON_RANDOM = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 0, 1], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [2, 3, 1, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 1, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": ((1, 0),), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 2, 3, 1], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 0, 2], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": 
(), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [1, 3, 2, 0], + ], + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3], + ], + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 3, 2], + ], +] + +TEST_2D_RANDOM = [ + [ + { + "ordering_type": OrderingType.RANDOM, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [[0, 1, 2, 3], [0, 1, 3, 2]], + ] +] + +TEST_3D = [ + [ + { + "ordering_type": OrderingType.RASTER_SCAN, + "spatial_dims": 3, + "dimensions": (1, 2, 2, 2), + "reflected_spatial_dims": (), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + }, + [0, 1, 2, 3, 4, 5, 6, 7], + ] +] + +TEST_ORDERING_TYPE_FAILURE = [ + [ + { + "ordering_type": "hilbert", + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ], +] + +TEST_ORDERING_TRANSFORMATION_FAILURE = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": ((1, 0),), + "rot90_axes": ((0, 1),), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + "flip", + ), + } + ], +] + +TEST_REVERT = [ + [ + { + "ordering_type": OrderingType.S_CURVE, + "spatial_dims": 2, + "dimensions": (1, 2, 2), + "reflected_spatial_dims": (True, False), + "transpositions_axes": (), + "rot90_axes": (), + "transformation_order": ( + OrderingTransformations.TRANSPOSE.value, + OrderingTransformations.ROTATE_90.value, + OrderingTransformations.REFLECT.value, + ), + } + ] +] + + +class TestOrdering(unittest.TestCase): + @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D) + def test_ordering(self, input_param, expected_sequence_ordering): + ordering = Ordering(**input_param) + self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True)) + + @parameterized.expand(TEST_ORDERING_TYPE_FAILURE) + def test_ordering_type_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE) + def 
test_ordering_transformation_failure(self, input_param): + with self.assertRaises(ValueError): + Ordering(**input_param) + + @parameterized.expand(TEST_2D_RANDOM) + def test_random(self, input_param, not_in_expected_sequence_ordering): + ordering = Ordering(**input_param) + + not_in = [ + np.array_equal(sequence, ordering.get_sequence_ordering(), equal_nan=True) + for sequence in not_in_expected_sequence_ordering + ] + + self.assertFalse(np.any(not_in)) + + @parameterized.expand(TEST_REVERT) + def test_revert(self, input_param): + sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten() + + ordering = Ordering(**input_param) + + reverted_sequence = sequence[ordering.get_sequence_ordering()] + reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()] + + self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True)) + + +if __name__ == "__main__": + unittest.main() From 7992e5ab74334174f80551096ce68a10e7f26733 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 6 Feb 2023 14:12:14 +0000 Subject: [PATCH 08/12] Add DecoderOnlyTransformer (#225) * Add AutoregressiveTransformer * Add cross-attention and rename model Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/__init__.py | 1 + generative/networks/nets/transformer.py | 61 +++++++++++++++++++++++++ requirements-dev.txt | 1 + tests/min_tests.py | 1 + tests/test_transformer.py | 42 +++++++++++++++++ 5 files changed, 106 insertions(+) create mode 100644 generative/networks/nets/transformer.py create mode 100644 tests/test_transformer.py diff --git a/generative/networks/nets/__init__.py b/generative/networks/nets/__init__.py index 8f7b51d8..ed15c2f8 100644 --- a/generative/networks/nets/__init__.py +++ b/generative/networks/nets/__init__.py @@ -12,4 +12,5 @@ from .autoencoderkl import AutoencoderKL from .diffusion_model_unet import DiffusionModelUNet from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator +from .transformer import DecoderOnlyTransformer from .vqvae import VQVAE diff --git a/generative/networks/nets/transformer.py b/generative/networks/nets/transformer.py new file mode 100644 index 00000000..84476aef --- /dev/null +++ b/generative/networks/nets/transformer.py @@ -0,0 +1,61 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +from x_transformers import Decoder, TransformerWrapper + +__all__ = ["DecoderOnlyTransformer"] + + +class DecoderOnlyTransformer(nn.Module): + """Decoder-only (Autoregressive) Transformer model. + + Args: + num_tokens: Number of tokens in the vocabulary. + max_seq_len: Maximum sequence length. + attn_layers_dim: Dimensionality of the attention layers. + attn_layers_depth: Number of attention layers. + attn_layers_heads: Number of attention heads. + with_cross_attention: Whether to use cross attention for conditioning. 
+ """ + + def __init__( + self, + num_tokens: int, + max_seq_len: int, + attn_layers_dim: int, + attn_layers_depth: int, + attn_layers_heads: int, + with_cross_attention: bool = False, + ) -> None: + super().__init__() + self.num_tokens = num_tokens + self.max_seq_len = max_seq_len + self.attn_layers_dim = attn_layers_dim + self.attn_layers_depth = attn_layers_depth + self.attn_layers_heads = attn_layers_heads + + self.model = TransformerWrapper( + num_tokens=self.num_tokens, + max_seq_len=self.max_seq_len, + attn_layers=Decoder( + dim=self.attn_layers_dim, + depth=self.attn_layers_depth, + heads=self.attn_layers_heads, + cross_attend=with_cross_attention, + ), + ) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(x, context=context) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1a9ddd18..2c8e5a6d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -54,3 +54,4 @@ nni optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded lpips==0.1.4 +x-transformers==1.8.1 diff --git a/tests/min_tests.py b/tests/min_tests.py index 82fb1130..b4373dd8 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -33,6 +33,7 @@ def run_testsuit(): "test_integration_workflows_adversarial", "test_latent_diffusion_inferer", "test_diffusion_inferer", + "test_transformer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..9ddb4ca7 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,42 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from monai.networks import eval_mode + +from generative.networks.nets import DecoderOnlyTransformer + + +class TestDecoderOnlyTransformer(unittest.TestCase): + def test_unconditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16))) + + def test_conditioned_models(self): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + ) + with eval_mode(net): + net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) + + +if __name__ == "__main__": + unittest.main() From cecdce30db3f43ef21d47451bfb4f888abf657cb Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 6 Feb 2023 15:07:32 +0000 Subject: [PATCH 09/12] Remove ch_mult from AutoencoderKL (#220) * Change num_channels to Sequence * Update tutorials --- generative/networks/nets/autoencoderkl.py | 174 ++++++++++-------- tests/test_autoencoderkl.py | 31 ++-- tests/test_latent_diffusion_inferer.py | 3 +- .../2d_autoencoderkl_tutorial.ipynb | 3 +- .../2d_autoencoderkl_tutorial.py | 3 +- .../3d_autoencoderkl_tutorial.ipynb | 3 +- .../3d_autoencoderkl_tutorial.py | 3 +- 7 files changed, 115 insertions(+), 105 deletions(-) diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index d081cad5..d13b26ad 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -296,16 +296,12 @@ class Encoder(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. - num_channels: number of filters in the first downsampling. + num_channels: sequence of block output channels. out_channels: number of channels in the bottom layer (latent space) of the autoencoder. - ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if - you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2], - the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2, - and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels. num_res_blocks: number of residual blocks (see ResBlock) per level. norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. norm_eps: epsilon for the normalization. - attention_levels: indicate which level from ch_mult contain an attention block. + attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. 
""" @@ -313,20 +309,15 @@ def __init__( self, spatial_dims: int, in_channels: int, - num_channels: int, + num_channels: Sequence[int], out_channels: int, - ch_mult: Sequence[int], num_res_blocks: int, norm_num_groups: int, norm_eps: float, - attention_levels: Optional[Sequence[bool]] = None, + attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, ) -> None: super().__init__() - - if attention_levels is None: - attention_levels = (False,) * len(ch_mult) - self.spatial_dims = spatial_dims self.in_channels = in_channels self.num_channels = num_channels @@ -336,15 +327,13 @@ def __init__( self.norm_eps = norm_eps self.attention_levels = attention_levels - in_ch_mult = (1,) + tuple(ch_mult) - blocks = [] # Initial convolution blocks.append( Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=num_channels, + out_channels=num_channels[0], strides=1, kernel_size=3, padding=1, @@ -353,52 +342,73 @@ def __init__( ) # Residual and downsampling blocks - for i in range(len(ch_mult)): - block_in_ch = num_channels * in_ch_mult[i] - block_out_ch = num_channels * ch_mult[i] + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + for _ in range(self.num_res_blocks): blocks.append( ResBlock( spatial_dims=spatial_dims, - in_channels=block_in_ch, + in_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - out_channels=block_out_ch, + out_channels=output_channel, ) ) - block_in_ch = block_out_ch + input_channel = output_channel if attention_levels[i]: blocks.append( AttentionBlock( spatial_dims=spatial_dims, - num_channels=block_in_ch, + num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) - if i != len(ch_mult) - 1: - blocks.append(Downsample(spatial_dims, block_in_ch)) + if not is_final_block: + blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel)) # Non-local attention block if with_nonlocal_attn is True: - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) blocks.append( - AttentionBlock( + ResBlock( spatial_dims=spatial_dims, - num_channels=block_in_ch, + in_channels=num_channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + out_channels=num_channels[-1], ) ) - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) # Normalise and convert to latent size - blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) + blocks.append( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) + ) blocks.append( Convolution( spatial_dims=self.spatial_dims, - in_channels=block_in_ch, + in_channels=num_channels[-1], out_channels=out_channels, strides=1, kernel_size=3, @@ -421,48 +431,39 @@ class Decoder(nn.Module): Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). - num_channels: number of filters in the last upsampling. + num_channels: sequence of block output channels. 
in_channels: number of channels in the bottom layer (latent space) of the autoencoder. out_channels: number of output channels. - ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the - last layer, there will be a transition from num_channels to out_channels. In the layers before that, - channels will be the product of the previous number of channels by ch_mult. num_res_blocks: number of residual blocks (see ResBlock) per level. norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. norm_eps: epsilon for the normalization. - attention_levels: indicate which level from ch_mult contain an attention block. + attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. """ def __init__( self, spatial_dims: int, - num_channels: int, + num_channels: Sequence[int], in_channels: int, out_channels: int, - ch_mult: Sequence[int], num_res_blocks: int, norm_num_groups: int, norm_eps: float, - attention_levels: Optional[Sequence[bool]] = None, + attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, ) -> None: super().__init__() - - if attention_levels is None: - attention_levels = (False,) * len(ch_mult) - self.spatial_dims = spatial_dims self.num_channels = num_channels self.in_channels = in_channels self.out_channels = out_channels - self.ch_mult = ch_mult self.num_res_blocks = num_res_blocks self.norm_num_groups = norm_num_groups self.norm_eps = norm_eps self.attention_levels = attention_levels - block_in_ch = num_channels * self.ch_mult[-1] + reversed_block_out_channels = list(reversed(num_channels)) blocks = [] # Initial convolution @@ -470,7 +471,7 @@ def __init__( Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=block_in_ch, + out_channels=reversed_block_out_channels[0], strides=1, kernel_size=3, padding=1, @@ -480,25 +481,53 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) blocks.append( AttentionBlock( spatial_dims=spatial_dims, - num_channels=block_in_ch, + num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) - for i in reversed(range(len(ch_mult))): - block_out_ch = num_channels * self.ch_mult[i] + reversed_attention_levels = list(reversed(attention_levels)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 for _ in range(self.num_res_blocks): - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) 
block_in_ch = block_out_ch - if attention_levels[i]: + if reversed_attention_levels[i]: blocks.append( AttentionBlock( spatial_dims=spatial_dims, @@ -508,8 +537,8 @@ def __init__( ) ) - if i != 0: - blocks.append(Upsample(spatial_dims, block_in_ch)) + if not is_final_block: + blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch)) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -542,14 +571,12 @@ class AutoencoderKL(nn.Module): spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. out_channels: number of output channels. - num_channels: number of filters in the first downsampling / last upsampling. - latent_channels: latent embedding dimension. - ch_mult: multiplier of the number of channels in each downsampling layer (+ initial one). i.e.: If you want 3 - downsamplings, it should be a 4-element list. num_res_blocks: number of residual blocks (see ResBlock) per level. + num_channels: sequence of block output channels. + attention_levels: sequence of levels to add attention. + latent_channels: latent embedding dimension. norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. norm_eps: epsilon for the normalization. - attention_levels: indicate which level from ch_mult contain an attention block. with_encoder_nonlocal_attn: if True use non-local attention block in the encoder. with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. """ @@ -557,35 +584,31 @@ class AutoencoderKL(nn.Module): def __init__( self, spatial_dims: int, - in_channels: int, - out_channels: int, - num_channels: int, - latent_channels: int, - ch_mult: Sequence[int], - num_res_blocks: int, + in_channels: int = 1, + out_channels: int = 1, + num_res_blocks: int = 2, + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + latent_channels: int = 3, norm_num_groups: int = 32, norm_eps: float = 1e-6, - attention_levels: Optional[Sequence[bool]] = None, with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, ) -> None: super().__init__() - if attention_levels is None: - attention_levels = (False,) * len(ch_mult) - # The number of channels should be multiple of num_groups - if (num_channels % norm_num_groups) != 0: - raise ValueError("AutoencoderKL expects number of channels being multiple of number of groups") + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups") - if len(ch_mult) != len(attention_levels): - raise ValueError("AutoencoderKL expects ch_mult being same size of attention_levels") + if len(num_channels) != len(attention_levels): + raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, out_channels=latent_channels, - ch_mult=ch_mult, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, @@ -597,7 +620,6 @@ def __init__( num_channels=num_channels, in_channels=latent_channels, out_channels=out_channels, - ch_mult=ch_mult, num_res_blocks=num_res_blocks, norm_num_groups=norm_num_groups, norm_eps=norm_eps, diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py 
index 11b1ec68..bb6af8f8 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -26,10 +26,9 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], - "attention_levels": None, + "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, }, @@ -42,9 +41,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, @@ -58,9 +56,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, @@ -74,9 +71,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, @@ -91,9 +87,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, @@ -109,9 +104,8 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, @@ -145,25 +139,24 @@ def test_model_channels_not_multiple_of_norm_num_group(self): spatial_dims=2, in_channels=1, out_channels=1, - num_channels=24, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), latent_channels=8, - ch_mult=[1, 1, 1], num_res_blocks=1, norm_num_groups=16, ) - def test_model_ch_mult_not_same_size_of_attention_levels(self): + def test_model_num_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): AutoencoderKL( spatial_dims=2, in_channels=1, out_channels=1, - num_channels=24, + num_channels=(24, 24, 24), + attention_levels=(False, False), latent_channels=8, - ch_mult=[1, 1, 1], num_res_blocks=1, norm_num_groups=16, - attention_levels=(True,), ) def test_shape_reconstruction(self): diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index e258f2e9..58394754 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -25,9 +25,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 8, + "num_channels": (8, 8, 8), "latent_channels": 3, - "ch_mult": [1, 1, 1], "attention_levels": [False, False, False], "num_res_blocks": 1, "with_encoder_nonlocal_attn": False, diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb index 5a0d0ed0..3a8c8019 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb @@ -598,9 +598,8 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", - " num_channels=128,\n", + " num_channels=(128, 256, 384),\n", " latent_channels=8,\n", - " ch_mult=(1, 2, 3),\n", " num_res_blocks=1,\n", " norm_num_groups=32,\n", " attention_levels=(False, False, True),\n", 
diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py index 9081de10..aab95a25 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py @@ -121,9 +121,8 @@ spatial_dims=2, in_channels=1, out_channels=1, - num_channels=128, + num_channels=(128, 256, 384), latent_channels=8, - ch_mult=(1, 2, 3), num_res_blocks=1, norm_num_groups=32, attention_levels=(False, False, True), diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb index 8e164b82..fd2eb249 100644 --- a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb +++ b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb @@ -501,9 +501,8 @@ " spatial_dims=3,\n", " in_channels=1,\n", " out_channels=1,\n", - " num_channels=32,\n", + " num_channels=(32, 64, 64),\n", " latent_channels=3,\n", - " ch_mult=(1, 2, 2),\n", " num_res_blocks=1,\n", " norm_num_groups=32,\n", " attention_levels=(False, False, True),\n", diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py index 97979189..cfce9e80 100644 --- a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py +++ b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py @@ -175,9 +175,8 @@ spatial_dims=3, in_channels=1, out_channels=1, - num_channels=32, + num_channels=(32, 64, 64), latent_channels=3, - ch_mult=(1, 2, 2), num_res_blocks=1, norm_num_groups=32, attention_levels=(False, False, True), From 69118fac311f9ad9a10f4d0d78863892f2cf41a6 Mon Sep 17 00:00:00 2001 From: Jessica Dafflon Date: Tue, 7 Feb 2023 12:23:01 -0500 Subject: [PATCH 10/12] Fix print messages for MS-SSIM (#230) * Fix print messages for MS-SSIM * Fix f-strings print (#230) --- generative/metrics/ms_ssim.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 2a5047c7..769a7ff3 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -80,10 +80,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ if not x.shape == y.shape: - raise ValueError( - f"Input images should have the same dimensions, \ - but got {x.shape} and {y.shape}." - ) + raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.") for d in range(len(x.shape) - 1, 1, -1): x = x.squeeze(dim=d) @@ -94,10 +91,7 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: elif len(x.shape) == 5: avg_pool = F.avg_pool3d else: - raise ValueError( - f"Input images should be 4-d or 5-d tensors, but \ - got {x.shape}" - ) + raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}") if self.weights is None: # as per Ref 1 - Sec 3.2. 
@@ -109,14 +103,14 @@ def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for idx, shape_size in enumerate(x.shape[2:]): if shape_size % divisible_by != 0: raise ValueError( - f"Image size needs to be divisible by {divisible_by} but \ - dimension {idx + 2} has size {shape_size}" + f"Image size needs to be divisible by {divisible_by} but " + f"dimension {idx + 2} has size {shape_size}" ) if shape_size < bigger_than: raise ValueError( - f"Image size should be larger than {bigger_than} due to \ - the {len(self.weights) - 1} downsamplings in MS-SSIM." + f"Image size should be larger than {bigger_than} due to " + f"the {len(self.weights) - 1} downsamplings in MS-SSIM." ) levels = self.weights.shape[0] From eeeca8f3e76e8c2460606e50895e883aaf06d3dd Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 7 Feb 2023 13:09:20 -0600 Subject: [PATCH 11/12] Update pretrained diffusion model (#233) * use find_unused_parameters=True now necessary for training the DDPM with DDP * Download updated checkpoint --- .../2d_ddpm/2d_ddpm_compare_schedulers.ipynb | 4 ++-- .../2d_ddpm/2d_ddpm_compare_schedulers.py | 4 ++-- .../generative/2d_ddpm/2d_ddpm_inpainting.ipynb | 4 ++-- .../generative/2d_ddpm/2d_ddpm_inpainting.py | 4 ++-- .../generative/2d_ddpm/2d_ddpm_tutorial.ipynb | 15 +++++++++++++-- tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py | 4 ++-- .../distributed_training/ddpm_training_ddp.py | 2 +- 7 files changed, 24 insertions(+), 13 deletions(-) diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb index 62de8796..1cadad2c 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb @@ -818,7 +818,7 @@ "use_pretrained = False\n", "\n", "if use_pretrained:\n", - " model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n", + " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n", "else:\n", " n_epochs = 100\n", " val_interval = 10\n", @@ -1096,7 +1096,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.13" } }, "nbformat": 4, diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py index a6dfe00b..dee1bed2 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py +++ b/tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.14.4 +# jupytext_version: 1.14.1 # kernelspec: # display_name: Python 3 (ipykernel) # language: python @@ -207,7 +207,7 @@ use_pretrained = False if use_pretrained: - model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device) + model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device) else: n_epochs = 100 val_interval = 10 diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb index 21d5e83f..1fc3f9f0 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb @@ -636,7 +636,7 @@ "use_pretrained = False\n", "\n", "if use_pretrained:\n", - " model = 
torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n", + " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n", "else:\n", " n_epochs = 50\n", " val_interval = 5\n", @@ -914,7 +914,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.13" } }, "nbformat": 4, diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py index ea6a1f8f..58874039 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py +++ b/tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.14.4 +# jupytext_version: 1.14.1 # kernelspec: # display_name: Python 3 (ipykernel) # language: python @@ -191,7 +191,7 @@ use_pretrained = False if use_pretrained: - model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device) + model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device) else: n_epochs = 50 val_interval = 5 diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb index 673170a8..40595441 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb @@ -42,6 +42,7 @@ "execution_count": 2, "id": "dd62a552", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -137,6 +138,7 @@ "execution_count": 3, "id": "8fc58c80", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -169,6 +171,7 @@ "execution_count": 4, "id": "ad5a1948", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -194,6 +197,7 @@ "execution_count": 5, "id": "65e1c200", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -232,6 +236,7 @@ "execution_count": 6, "id": "e2f9bebd", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -271,6 +276,7 @@ "execution_count": 7, "id": "938318c2", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -320,6 +326,7 @@ "execution_count": 8, "id": "b698f4f8", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -372,6 +379,7 @@ "execution_count": 9, "id": "2c52e4f4", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -415,6 +423,7 @@ "execution_count": 10, "id": "0f697a13", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -763,7 +772,7 @@ "use_pretrained = False\n", "\n", "if use_pretrained:\n", - " model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n", + " model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n", "else:\n", " n_epochs = 75\n", " val_interval = 5\n", @@ -852,6 +861,7 @@ "execution_count": 11, "id": "2cdcda81", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -901,6 +911,7 @@ "execution_count": 12, "id": "1427e5d4", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -984,7 +995,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.13" } }, "nbformat": 4, diff 
--git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
index 0384d33e..2d81ddb6 100644
--- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
+++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
@@ -6,7 +6,7 @@
 #       extension: .py
 #       format_name: percent
 #       format_version: '1.3'
-#       jupytext_version: 1.14.4
+#       jupytext_version: 1.14.1
 #   kernelspec:
 #     display_name: Python 3 (ipykernel)
 #     language: python
@@ -190,7 +190,7 @@
 use_pretrained = False

 if use_pretrained:
-    model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
+    model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
 else:
     n_epochs = 75
     val_interval = 5
diff --git a/tutorials/generative/distributed_training/ddpm_training_ddp.py b/tutorials/generative/distributed_training/ddpm_training_ddp.py
index f111d52a..07fab1b0 100644
--- a/tutorials/generative/distributed_training/ddpm_training_ddp.py
+++ b/tutorials/generative/distributed_training/ddpm_training_ddp.py
@@ -197,7 +197,7 @@ def main_worker(args):
     inferer = DiffusionInferer(scheduler)

     # wrap the model with DistributedDataParallel module
-    model = DistributedDataParallel(model, device_ids=[device])
+    model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

     # start a typical PyTorch training
     best_metric = 10000
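Patch 11 pins the torch.hub repository to the v0.2 tag when fetching the updated pretrained checkpoint. A minimal usage sketch, assembled from the tutorial diff above — the repository name, tag, and model key are taken verbatim from the patch; only the device selection is added here:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pin torch.hub to the v0.2 tag so the checkpoint matching the updated
    # training run (DDP with find_unused_parameters=True) is downloaded.
    model = torch.hub.load(
        "marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True
    ).to(device)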
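The ch_mult removal earlier in this series also changes how AutoencoderKL is configured: num_channels is now a per-level sequence that must be the same length as attention_levels, and every entry must be divisible by norm_num_groups. A minimal sketch of the new call using the values from the updated 2D tutorial; the import path and the three-tuple return of forward follow the repository's tutorials and should be treated as assumptions here:

    import torch

    from generative.networks.nets import AutoencoderKL

    # Old style (removed): num_channels=128, ch_mult=(1, 2, 3)
    # New style: pass the resolved per-level channel counts directly.
    model = AutoencoderKL(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(128, 256, 384),           # one entry per level
        attention_levels=(False, False, True),  # must match len(num_channels)
        latent_channels=8,
        num_res_blocks=1,
        norm_num_groups=32,                     # each channel count must be divisible by this
    )

    # Mismatched lengths or non-divisible channel counts raise ValueError in __init__.
    reconstruction, z_mu, z_sigma = model(torch.randn(1, 1, 64, 64))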