From ee427afdb57b20d893405d9b5953f8166c03c03c Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 13 Mar 2023 13:20:34 +0000 Subject: [PATCH 1/2] Moving DiffusionPrepareBatch Signed-off-by: Eric Kerfoot --- generative/engines/__init__.py | 1 + generative/engines/prepare_batch.py | 90 +++++++++++++++++++ .../mednist_ddpm/bundle/configs/train.yaml | 2 +- .../mednist_ddpm/bundle/scripts/__init__.py | 49 ---------- .../2d_ddpm/2d_ddpm_tutorial_ignite.ipynb | 74 +++------------ .../2d_ddpm/2d_ddpm_tutorial_ignite.py | 50 +---------- 6 files changed, 105 insertions(+), 161 deletions(-) create mode 100644 generative/engines/prepare_batch.py diff --git a/generative/engines/__init__.py b/generative/engines/__init__.py index f76c669d..316012d4 100644 --- a/generative/engines/__init__.py +++ b/generative/engines/__init__.py @@ -12,3 +12,4 @@ from __future__ import annotations from .trainer import AdversarialTrainer +from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch diff --git a/generative/engines/prepare_batch.py b/generative/engines/prepare_batch.py new file mode 100644 index 00000000..878c4f9d --- /dev/null +++ b/generative/engines/prepare_batch.py @@ -0,0 +1,90 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, Callable, Iterable, Sequence, Optional, Dict, Union +import torch +import torch.nn as nn +from monai.engines import PrepareBatch, default_prepare_batch + + +class DiffusionPrepareBatch(PrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + return the image and noise field as the image/target pair plus the noise field in kwargs under the key "noise". + This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "conditioning". + + """ + + def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps + + def get_noise(self, images:torch.Tensor) -> torch.Tensor: + """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" + return torch.randn_like(images) + + def get_timesteps(self, images:torch.Tensor) -> torch.Tensor: + """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + + def get_target(self, images:torch.Tensor, noise:torch.Tensor, timesteps:torch.Tensor) -> torch.Tensor: + """Return the target for the loss function, this is the `noise` value by default.""" + return noise + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False,
+ **kwargs, + ): + images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) + + target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs = {"noise": noise, "timesteps": timesteps} + + if self.condition_name is not None and isinstance(batchdata, Mapping): + infer_kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) + + # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value + return images, target, (), infer_kwargs + + +class VPredictionPrepareBatch(DiffusionPrepareBatch): + """ + This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. + + Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and + from this compute the velocity using the provided scheduler. This value is used as the target in place of the + noise field itself although the noise field is in the kwargs under the key "noise". This assumes the inferer + being used in conjunction with this class expects a "noise" parameter to be provided. + + If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition + field to be passed to the inferer. This will appear in the keyword arguments under the key "conditioning".
+ + """ + + def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: + super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) + self.scheduler = scheduler + + def get_target(self, images, noise, timesteps): + return self.scheduler.get_velocity(images, noise, timesteps) + \ No newline at end of file diff --git a/model-zoo/models/mednist_ddpm/bundle/configs/train.yaml b/model-zoo/models/mednist_ddpm/bundle/configs/train.yaml index 739b3c1f..459e23bd 100644 --- a/model-zoo/models/mednist_ddpm/bundle/configs/train.yaml +++ b/model-zoo/models/mednist_ddpm/bundle/configs/train.yaml @@ -91,7 +91,7 @@ optimizer: lr: '@lr' prepare_batch: - _target_: scripts.DiffusionPrepareBatch + _target_: generative.engines.DiffusionPrepareBatch num_train_timesteps: '@num_train_timesteps' val_handlers: diff --git a/model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py b/model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py index 344830d2..c44e4a34 100644 --- a/model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py +++ b/model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py @@ -1,54 +1,5 @@ from __future__ import annotations -from typing import Dict, Mapping, Optional, Union - -import torch -from monai.engines import PrepareBatch, default_prepare_batch - - -class DiffusionPrepareBatch(PrepareBatch): - """ - This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. - - Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and - return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". - This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. 
- - If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition - field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". - - """ - - def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None: - self.condition_name = condition_name - self.num_train_timesteps = num_train_timesteps - - def get_noise(self, images: torch.Tensor) -> torch.Tensor: - """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" - return torch.randn_like(images) - - def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: - return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() - - def __call__( - self, - batchdata: Dict[str, torch.Tensor], - device: Union[str, torch.device] | None = None, - non_blocking: bool = False, - **kwargs, - ): - images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) - noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) - timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) - - kwargs = {"noise": noise, "timesteps": timesteps} - - if self.condition_name is not None and isinstance(batchdata, Mapping): - kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) - - # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value - return images, noise, (), kwargs - def inv_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: """ diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb index f831108e..53f95818 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb @@ -44,6 +44,7 @@ "execution_count": 2, "id": 
"dd62a552", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -112,6 +113,7 @@ "from monai.utils import first, set_determinism\n", "\n", "from generative.inferers import DiffusionInferer\n", + "from generative.engines import DiffusionPrepareBatch\n", "\n", "# TODO: Add right import reference after deployed\n", "from generative.networks.nets import DiffusionModelUNet\n", @@ -139,6 +141,7 @@ "execution_count": 3, "id": "8fc58c80", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -171,6 +174,7 @@ "execution_count": 4, "id": "ad5a1948", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -196,6 +200,7 @@ "execution_count": 5, "id": "65e1c200", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -234,6 +239,7 @@ "execution_count": 6, "id": "e2f9bebd", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -273,6 +279,7 @@ "execution_count": 7, "id": "938318c2", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -322,6 +329,7 @@ "execution_count": 8, "id": "b698f4f8", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -374,6 +382,7 @@ "execution_count": 9, "id": "2c52e4f4", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -402,67 +411,6 @@ "inferer = DiffusionInferer(scheduler)" ] }, - { - "cell_type": "markdown", - "id": "655fa0a2-91f7-45e6-b3f8-259b76fe7e74", - "metadata": {}, - "source": [ - "### Define a class for preparing batches" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "15e46af7-c3e9-409b-ab1f-5884ada2729f", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "class DiffusionPrepareBatch(PrepareBatch):\n", - " \"\"\"\n", - " This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.\n", - "\n", - " Assuming a supervised training process, it will generate a 
noise field using `get_noise` for an input image, and\n", - " return the image and noise field as the image/target pair plus the noise field the kwargs under the key \"noise\".\n", - " This assumes the inferer being used in conjunction with this class expects a \"noise\" parameter to be provided.\n", - "\n", - " If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition\n", - " field to be passed to the inferer. This will appear in the keyword arguments under the key \"condition\".\n", - "\n", - " \"\"\"\n", - "\n", - " def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n", - " self.condition_name = condition_name\n", - " self.num_train_timesteps = num_train_timesteps\n", - "\n", - " def get_noise(self, images):\n", - " \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n", - " return torch.randn_like(images)\n", - "\n", - " def get_timesteps(self, images):\n", - " return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n", - "\n", - " def __call__(\n", - " self,\n", - " batchdata: Dict[str, torch.Tensor],\n", - " device: Optional[Union[str, torch.device]] = None,\n", - " non_blocking: bool = False,\n", - " **kwargs,\n", - " ):\n", - " images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n", - " noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n", - " timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n", - "\n", - " kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n", - "\n", - " if self.condition_name is not None and isinstance(batchdata, Mapping):\n", - " kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n", - "\n", - " # return input, target, arguments, and keyword arguments where noise is the target and also a keyword 
value\n", - " return images, noise, (), kwargs" - ] - }, { "cell_type": "markdown", "id": "5a316067", @@ -477,6 +425,7 @@ "execution_count": 11, "id": "0f697a13", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -2207,6 +2156,7 @@ "execution_count": 12, "id": "1427e5d4", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -2291,7 +2241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.13" } }, "nbformat": 4, diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py index bc67a074..24307b90 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py @@ -64,6 +64,7 @@ from monai.utils import first, set_determinism from generative.inferers import DiffusionInferer +from generative.engines import DiffusionPrepareBatch # TODO: Add right import reference after deployed from generative.networks.nets import DiffusionModelUNet @@ -183,55 +184,6 @@ optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5) inferer = DiffusionInferer(scheduler) -# %% [markdown] -# ### Define a class for preparing batches - -# %% - - -class DiffusionPrepareBatch(PrepareBatch): - """ - This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. - - Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and - return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". - This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. - - If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition - field to be passed to the inferer. 
This will appear in the keyword arguments under the key "condition". - - """ - - def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None): - self.condition_name = condition_name - self.num_train_timesteps = num_train_timesteps - - def get_noise(self, images): - """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" - return torch.randn_like(images) - - def get_timesteps(self, images): - return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() - - def __call__( - self, - batchdata: Dict[str, torch.Tensor], - device: Optional[Union[str, torch.device]] = None, - non_blocking: bool = False, - **kwargs, - ): - images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) - noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) - timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) - - kwargs = {"noise": noise, "timesteps": timesteps} - - if self.condition_name is not None and isinstance(batchdata, Mapping): - kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) - - # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value - return images, noise, (), kwargs - # %% [markdown] # ### Model training From ff5497a1b5e01987d6cc25dcf999d2ca1fcffde7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 13 Mar 2023 14:10:28 +0000 Subject: [PATCH 2/2] Fixes Signed-off-by: Eric Kerfoot --- generative/engines/__init__.py | 2 +- generative/engines/prepare_batch.py | 14 ++++++++------ generative/networks/nets/patchgan_discriminator.py | 1 - tests/min_tests.py | 1 - tests/runner.py | 1 - tests/test_diffusion_inferer.py | 1 - tests/test_patch_gan.py | 1 - tests/utils.py | 1 - tutorials/generative/2d_ldm/2d_ldm_tutorial.py | 1 - .../2d_vqvae_transformer_tutorial.py | 4 +--- 
.../distributed_training/ddpm_training_ddp.py | 1 - 11 files changed, 10 insertions(+), 18 deletions(-) diff --git a/generative/engines/__init__.py b/generative/engines/__init__.py index 316012d4..db22bc23 100644 --- a/generative/engines/__init__.py +++ b/generative/engines/__init__.py @@ -11,5 +11,5 @@ from __future__ import annotations -from .trainer import AdversarialTrainer from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch +from .trainer import AdversarialTrainer diff --git a/generative/engines/prepare_batch.py b/generative/engines/prepare_batch.py index 878c4f9d..4f3693a5 100644 --- a/generative/engines/prepare_batch.py +++ b/generative/engines/prepare_batch.py @@ -11,7 +11,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Mapping, Callable, Iterable, Sequence, Optional, Dict, Union +from typing import Dict, Mapping, Optional, Union + import torch import torch.nn as nn from monai.engines import PrepareBatch, default_prepare_batch @@ -34,15 +35,15 @@ def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = Non self.condition_name = condition_name self.num_train_timesteps = num_train_timesteps - def get_noise(self, images:torch.Tensor) -> torch.Tensor: + def get_noise(self, images: torch.Tensor) -> torch.Tensor: """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" return torch.randn_like(images) - def get_timesteps(self, images:torch.Tensor) -> torch.Tensor: + def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() - def get_target(self, images:torch.Tensor, noise:torch.Tensor, timesteps:torch.Tensor) -> torch.Tensor: + def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: """Return the 
target for the loss function, this is the `noise` value by default.""" return noise @@ -61,7 +62,9 @@ def __call__( infer_kwargs = {"noise": noise, "timesteps": timesteps} if self.condition_name is not None and isinstance(batchdata, Mapping): - infer_kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) + infer_kwargs["conditioning"] = batchdata[self.condition_name].to( + device, non_blocking=non_blocking, **kwargs + ) # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value return images, target, (), infer_kwargs @@ -87,4 +90,3 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam def get_target(self, images, noise, timesteps): return self.scheduler.get_velocity(images, noise, timesteps) - \ No newline at end of file diff --git a/generative/networks/nets/patchgan_discriminator.py b/generative/networks/nets/patchgan_discriminator.py index b5a98c88..bf09b743 100644 --- a/generative/networks/nets/patchgan_discriminator.py +++ b/generative/networks/nets/patchgan_discriminator.py @@ -154,7 +154,6 @@ def __init__( dropout: float | tuple = 0.0, last_conv_kernel_size: int | None = None, ) -> None: - super().__init__() self.num_layers_d = num_layers_d self.num_channels = num_channels diff --git a/tests/min_tests.py b/tests/min_tests.py index 1bc9eed7..15bf3bfd 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -54,7 +54,6 @@ def run_testsuit(): if __name__ == "__main__": - # testing import submodules from monai.utils.module import load_submodules diff --git a/tests/runner.py b/tests/runner.py index 96a1d4a5..7a7cc9f2 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -114,7 +114,6 @@ def get_default_pattern(loader): if __name__ == "__main__": - # Parse input arguments args = parse_args() diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index c450ed3d..6faf0e68 100644 --- 
a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -53,7 +53,6 @@ class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_call(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) diff --git a/tests/test_patch_gan.py b/tests/test_patch_gan.py index 7e8df802..50bab99e 100644 --- a/tests/test_patch_gan.py +++ b/tests/test_patch_gan.py @@ -94,7 +94,6 @@ def test_too_small_shape(self): MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) def test_script(self): - net = MultiScalePatchDiscriminator( num_d=2, num_layers_d=3, diff --git a/tests/utils.py b/tests/utils.py index 1d5b8e9c..274f6aa9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -576,7 +576,6 @@ def run_process(func, args, kwargs, results): results.put(e) def __call__(self, obj): - if self.skip_timing: return obj diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py index 7b555345..3eb85258 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py @@ -406,7 +406,6 @@ scheduler.set_timesteps(num_inference_steps=1000) with torch.no_grad(): - z_mu, z_sigma = autoencoderkl.encode(image) z = autoencoderkl.sampling(z_mu, z_sigma) diff --git a/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py b/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py index c1890939..c769fd56 100644 --- a/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py +++ b/tutorials/generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.py @@ -341,10 +341,10 @@ # %% [markdown] # First we will define a function to allow us to generate random samples from the transformer. 
This will allow us to keep track of training progress as well to see how samples look during the training cycle + # %% @torch.no_grad() def generate(net, vqvae_model, starting_tokens, seq_len, **kwargs): - progress_bar = iter(range(seq_len)) latent_seq = starting_tokens.long() @@ -394,7 +394,6 @@ def generate(net, vqvae_model, starting_tokens, seq_len, **kwargs): progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110) progress_bar.set_description(f"Epoch {epoch}") for step, batch in progress_bar: - images = batch["image"].to(device) # Encode images using vqvae and transformer to 1D sequence quantizations = vqvae_model.index_quantize(images) @@ -428,7 +427,6 @@ def generate(net, vqvae_model, starting_tokens, seq_len, **kwargs): val_loss = 0 with torch.no_grad(): for val_step, batch in enumerate(val_loader, start=1): - images = batch["image"].to(device) # Encode images using vqvae and transformer to 1D sequence quantizations = vqvae_model.index_quantize(images) diff --git a/tutorials/generative/distributed_training/ddpm_training_ddp.py b/tutorials/generative/distributed_training/ddpm_training_ddp.py index 07fab1b0..de0f3734 100644 --- a/tutorials/generative/distributed_training/ddpm_training_ddp.py +++ b/tutorials/generative/distributed_training/ddpm_training_ddp.py @@ -83,7 +83,6 @@ def __init__( num_workers: int = 0, shuffle: bool = False, ) -> None: - if not os.path.isdir(root_dir): raise ValueError("root directory root_dir must be a directory.") self.section = section