Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generative/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@

from __future__ import annotations

from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch
from .trainer import AdversarialTrainer
92 changes: 92 additions & 0 deletions generative/engines/prepare_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Dict, Mapping, Optional, Union

import torch
import torch.nn as nn
from monai.engines import PrepareBatch, default_prepare_batch


class DiffusionPrepareBatch(PrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.

Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.

If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".

"""

def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images: torch.Tensor) -> torch.Tensor:
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
"""Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
"""Return the target for the loss function, this is the `noise` value by default."""
return noise

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
infer_kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
infer_kwargs["conditioning"] = batchdata[self.condition_name].to(
device, non_blocking=non_blocking, **kwargs
)

# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
return images, target, (), infer_kwargs


class VPredictionPrepareBatch(DiffusionPrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.

Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
from this compute the velocity using the provided scheduler. This value is used as the target in place of the
noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer
being used in conjunction with this class expects a "noise" parameter to be provided.

If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".

"""

def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
self.scheduler = scheduler

def get_target(self, images, noise, timesteps):
return self.scheduler.get_velocity(images, noise, timesteps)
1 change: 0 additions & 1 deletion generative/networks/nets/patchgan_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def __init__(
dropout: float | tuple = 0.0,
last_conv_kernel_size: int | None = None,
) -> None:

super().__init__()
self.num_layers_d = num_layers_d
self.num_channels = num_channels
Expand Down
2 changes: 1 addition & 1 deletion model-zoo/models/mednist_ddpm/bundle/configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ optimizer:
lr: '@lr'

prepare_batch:
_target_: scripts.DiffusionPrepareBatch
_target_: generative.engines.DiffusionPrepareBatch
num_train_timesteps: '@num_train_timesteps'

val_handlers:
Expand Down
49 changes: 0 additions & 49 deletions model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,5 @@
from __future__ import annotations

from typing import Dict, Mapping, Optional, Union

import torch
from monai.engines import PrepareBatch, default_prepare_batch


class DiffusionPrepareBatch(PrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.

Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.

If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".

"""

def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images: torch.Tensor) -> torch.Tensor:
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Union[str, torch.device] | None = None,
non_blocking: bool = False,
**kwargs,
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)

# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
return images, noise, (), kwargs


def inv_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
"""
Expand Down
1 change: 0 additions & 1 deletion tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def run_testsuit():


if __name__ == "__main__":

# testing import submodules
from monai.utils.module import load_submodules

Expand Down
1 change: 0 additions & 1 deletion tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def get_default_pattern(loader):


if __name__ == "__main__":

# Parse input arguments
args = parse_args()

Expand Down
1 change: 0 additions & 1 deletion tests/test_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
class TestDiffusionSamplingInferer(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_call(self, model_params, input_shape):

model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
Expand Down
1 change: 0 additions & 1 deletion tests/test_patch_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def test_too_small_shape(self):
MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])

def test_script(self):

net = MultiScalePatchDiscriminator(
num_d=2,
num_layers_d=3,
Expand Down
1 change: 0 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,6 @@ def run_process(func, args, kwargs, results):
results.put(e)

def __call__(self, obj):

if self.skip_timing:
return obj

Expand Down
74 changes: 12 additions & 62 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"execution_count": 2,
"id": "dd62a552",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -112,6 +113,7 @@
"from monai.utils import first, set_determinism\n",
"\n",
"from generative.inferers import DiffusionInferer\n",
"from generative.engines import DiffusionPrepareBatch\n",
"\n",
"# TODO: Add right import reference after deployed\n",
"from generative.networks.nets import DiffusionModelUNet\n",
Expand Down Expand Up @@ -139,6 +141,7 @@
"execution_count": 3,
"id": "8fc58c80",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -171,6 +174,7 @@
"execution_count": 4,
"id": "ad5a1948",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand All @@ -196,6 +200,7 @@
"execution_count": 5,
"id": "65e1c200",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -234,6 +239,7 @@
"execution_count": 6,
"id": "e2f9bebd",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -273,6 +279,7 @@
"execution_count": 7,
"id": "938318c2",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -322,6 +329,7 @@
"execution_count": 8,
"id": "b698f4f8",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -374,6 +382,7 @@
"execution_count": 9,
"id": "2c52e4f4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand Down Expand Up @@ -402,67 +411,6 @@
"inferer = DiffusionInferer(scheduler)"
]
},
{
"cell_type": "markdown",
"id": "655fa0a2-91f7-45e6-b3f8-259b76fe7e74",
"metadata": {},
"source": [
"### Define a class for preparing batches"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "15e46af7-c3e9-409b-ab1f-5884ada2729f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class DiffusionPrepareBatch(PrepareBatch):\n",
" \"\"\"\n",
" This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.\n",
"\n",
" Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and\n",
" return the image and noise field as the image/target pair plus the noise field the kwargs under the key \"noise\".\n",
" This assumes the inferer being used in conjunction with this class expects a \"noise\" parameter to be provided.\n",
"\n",
" If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition\n",
" field to be passed to the inferer. This will appear in the keyword arguments under the key \"condition\".\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n",
" self.condition_name = condition_name\n",
" self.num_train_timesteps = num_train_timesteps\n",
"\n",
" def get_noise(self, images):\n",
" \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n",
" return torch.randn_like(images)\n",
"\n",
" def get_timesteps(self, images):\n",
" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n",
"\n",
" def __call__(\n",
" self,\n",
" batchdata: Dict[str, torch.Tensor],\n",
" device: Optional[Union[str, torch.device]] = None,\n",
" non_blocking: bool = False,\n",
" **kwargs,\n",
" ):\n",
" images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n",
" noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n",
" timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n",
"\n",
" if self.condition_name is not None and isinstance(batchdata, Mapping):\n",
" kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value\n",
" return images, noise, (), kwargs"
]
},
{
"cell_type": "markdown",
"id": "5a316067",
Expand All @@ -477,6 +425,7 @@
"execution_count": 11,
"id": "0f697a13",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand Down Expand Up @@ -2207,6 +2156,7 @@
"execution_count": 12,
"id": "1427e5d4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -2291,7 +2241,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.13"
}
},
"nbformat": 4,
Expand Down
Loading