From 7adb8203263a5fb798b3175736eb913d6ed75455 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 6 Mar 2023 13:53:28 +0000 Subject: [PATCH 01/14] Adding components and refactoring of schedulers (DDPM only) Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/ddpm.py | 126 ++++++------------- generative/networks/schedulers/scheduler.py | 133 ++++++++++++++++++++ generative/utils/__init__.py | 2 + generative/utils/component_store.py | 82 ++++++++++++ generative/utils/misc.py | 25 ++++ tests/test_scheduler_ddpm.py | 8 ++ 6 files changed, 289 insertions(+), 87 deletions(-) create mode 100644 generative/networks/schedulers/scheduler.py create mode 100644 generative/utils/component_store.py create mode 100644 generative/utils/misc.py diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 2f25f9f1..2ae139a2 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -35,8 +35,25 @@ import torch import torch.nn as nn +from monai.utils import StrEnum -class DDPMScheduler(nn.Module): +from .scheduler import Scheduler, BetaSchedules +from generative.utils import unsqueeze_right + + +class DDPMVarianceType(StrEnum): + FIXED_SMALL="fixed_small" + FIXED_LARGE="fixed_large" + LEARNED="learned" + LEARNED_RANGE="learned_range" + + +class DDPMPRedictionType(StrEnum): + EPSILON="epsiolon" + SAMPLE="sample" + V_PREDICTION="v_prediction" + +class DDPMScheduler(Scheduler): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" @@ -46,9 +63,9 @@ class DDPMScheduler(nn.Module): num_train_timesteps: number of diffusion steps used to train the model. beta_start: the starting `beta` value of inference. beta_end: the final `beta` value. - beta_schedule: {``"linear"``, ``"scaled_linear"``} + beta_schedule: member of BetaSchedules the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. - variance_type: {``"fixed_small"``, ``"fixed_large"``, ``"learned"``, ``"learned_range"``} + variance_type: member of VarianceType options to clip the variance used when adding noise to the denoised sample. clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``} @@ -63,40 +80,21 @@ def __init__( beta_start: float = 1e-4, beta_end: float = 2e-2, beta_schedule: str = "linear", - variance_type: str = "fixed_small", + variance_type: str = DDPMVarianceType.FIXED_SMALL, clip_sample: bool = True, - prediction_type: str = "epsilon", + prediction_type: str = DDPMPRedictionType.EPSILON, ) -> None: - super().__init__() - self.beta_schedule = beta_schedule - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]: - raise ValueError( - f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" - ) - - self.prediction_type = prediction_type - - self.num_train_timesteps = num_train_timesteps - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) - + super().__init__(num_train_timesteps, beta_start, beta_end, beta_schedule, prediction_type) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPRedictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPRedictionType`") + self.clip_sample = clip_sample self.variance_type = variance_type - - # settable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.prediction_type = prediction_type def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ @@ -164,13 +162,13 @@ def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] # hacks - were probably added for training stability - if self.variance_type == "fixed_small": + if self.variance_type == DDPMVarianceType.FIXED_SMALL: variance = torch.clamp(variance, min=1e-20) - elif self.variance_type == "fixed_large": + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: variance = self.betas[timestep] - elif self.variance_type == "learned": + elif self.variance_type == DDPMVarianceType.LEARNED: return predicted_variance - elif self.variance_type == "learned_range": + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE: min_log = variance max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 @@ -207,11 +205,11 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.prediction_type == "epsilon": + if self.prediction_type == DDPMPRedictionType.EPSILON: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.prediction_type == "sample": + elif self.prediction_type == DDPMPRedictionType.SAMPLE: pred_original_sample = model_output - elif self.prediction_type == "v_prediction": + elif self.prediction_type == DDPMPRedictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" @@ -238,50 +236,4 @@ def step( pred_prev_sample = pred_prev_sample + variance return pred_prev_sample, pred_original_sample - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
- - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity + \ No newline at end of file diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py new file mode 100644 index 00000000..f343dd4c --- /dev/null +++ b/generative/networks/schedulers/scheduler.py @@ -0,0 +1,133 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations +from typing import Callable +import numpy as np +import torch +import torch.nn as nn + +from generative.utils import ComponentStore, unsqueeze_right + + +BetaSchedules = ComponentStore("BetaSchedules", "Functions to generate beta schedules given start/end values and steps") + + +@BetaSchedules.add_def("linear", "Linear beta schedule") +def _linear(beta_start, beta_end, num_train_timesteps): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@BetaSchedules.add_def("scaled_linear", "Scaled linear beta schedule") +def _scaled_linear(beta_start, beta_end, num_train_timesteps): + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@BetaSchedules.add_def("sigmoid", "Sigmoid beta schedule") +def _sigmoid(beta_start, beta_end, num_train_timesteps, sig_range=6): + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@BetaSchedules.add_def("cosine", "Cosine beta schedule") +def _cosine(beta_start, beta_end, num_train_timesteps, s=0.008): + x = torch.linspace(0, num_train_timesteps, num_train_timesteps - 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0.0001, 0.9999) + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a beta noise schedule. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + beta_start: the starting `beta` value of inference. + beta_end: the final `beta` value. + beta_schedule: member of BetaSchedules + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 1e-4, + beta_end: float = 2e-2, + beta_schedule: str = "linear", + prediction_type: str = "epsilon", + ) -> None: + super().__init__() + self.betas = BetaSchedules[beta_schedule](beta_start, beta_end, num_train_timesteps) + + self.num_train_timesteps = num_train_timesteps + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
+ + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/generative/utils/__init__.py b/generative/utils/__init__.py index be9d721b..bf62a696 100644 --- a/generative/utils/__init__.py +++ b/generative/utils/__init__.py @@ -12,3 +12,5 @@ from __future__ import annotations from .enums import AdversarialIterationEvents, AdversarialKeys +from .component_store import ComponentStore +from .misc import * diff --git a/generative/utils/component_store.py b/generative/utils/component_store.py new file mode 100644 index 00000000..765142e2 --- /dev/null +++ b/generative/utils/component_store.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from collections import namedtuple +from keyword import iskeyword +from typing import TypeVar, Callable, Any, Dict, Iterable + +T = TypeVar("T") + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore: + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + These objects act as global named places for storing components for objects parameterised by component names. 
+ """ + + _Component = namedtuple("Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: Dict[str, self._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/generative/utils/misc.py b/generative/utils/misc.py new file mode 100644 index 00000000..c8a00177 --- /dev/null +++ b/generative/utils/misc.py @@ -0,0 +1,25 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations +from typing import TypeVar + +T = TypeVar("T") + + +def unsqueeze_right(arr: T, ndim: int) -> T: + """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(...,) + (None,) * (ndim - arr.ndim)] + + +def unsqueeze_left(arr: T, ndim: int) -> T: + """Preppend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index 7e07563e..dcb388a3 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -54,6 +54,14 @@ def test_step_shape(self, input_param, input_shape, expected_shape): output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) self.assertEqual(output_step[0].shape, expected_shape) self.assertEqual(output_step[1].shape, expected_shape) + + @parameterized.expand(TEST_CASES) + def test_get_velocity_shape(self, input_param, input_shape, expected_shape): + scheduler = DDPMScheduler(**input_param) + sample = torch.randn(input_shape) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() + velocity = scheduler.get_velocity(sample=sample, noise=sample,timesteps=timesteps) + self.assertEqual(velocity.shape, expected_shape) def test_step_learned(self): for variance_type in ["learned", "learned_range"]: From 301aeef3e47b058b4abe65e3d655dfd7124db975 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 6 Mar 2023 14:02:59 +0000 Subject: [PATCH 02/14] Fix Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/ddpm.py | 24 ++++++++++++------------ tests/test_scheduler_ddpm.py | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 2ae139a2..f60f0ae2 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -42,16 +42,17 @@ class DDPMVarianceType(StrEnum): - FIXED_SMALL="fixed_small" - FIXED_LARGE="fixed_large" - LEARNED="learned" - LEARNED_RANGE="learned_range" - - + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + class DDPMPRedictionType(StrEnum): - EPSILON="epsiolon" - SAMPLE="sample" - V_PREDICTION="v_prediction" + EPSILON = "epsiolon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + class DDPMScheduler(Scheduler): """ @@ -88,10 +89,10 @@ def __init__( if variance_type not in DDPMVarianceType.__members__.values(): raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") - + if prediction_type not in DDPMPRedictionType.__members__.values(): raise ValueError("Argument `prediction_type` must be a member of `DDPMPRedictionType`") - + self.clip_sample = clip_sample self.variance_type = variance_type self.prediction_type = prediction_type @@ -236,4 +237,3 @@ def step( pred_prev_sample = pred_prev_sample + variance return pred_prev_sample, pred_original_sample - \ No newline at end of file diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index dcb388a3..01f0b7b4 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -54,13 +54,13 @@ def test_step_shape(self, input_param, input_shape, expected_shape): output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) self.assertEqual(output_step[0].shape, expected_shape) self.assertEqual(output_step[1].shape, expected_shape) - + 
@parameterized.expand(TEST_CASES) def test_get_velocity_shape(self, input_param, input_shape, expected_shape): scheduler = DDPMScheduler(**input_param) sample = torch.randn(input_shape) timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() - velocity = scheduler.get_velocity(sample=sample, noise=sample,timesteps=timesteps) + velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) self.assertEqual(velocity.shape, expected_shape) def test_step_learned(self): From 6a2c2f7e4277d3c835f451c6fb04db7b7d079d88 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Mon, 13 Mar 2023 14:48:05 +0000 Subject: [PATCH 03/14] Adding tests I forgot to add Signed-off-by: Eric Kerfoot --- tests/test_component_store.py | 72 +++++++++++++++++++++++++++++++++++ tests/test_misc.py | 47 +++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 tests/test_component_store.py create mode 100644 tests/test_misc.py diff --git a/tests/test_component_store.py b/tests/test_component_store.py new file mode 100644 index 00000000..c6b43bde --- /dev/null +++ b/tests/test_component_store.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from generative.utils import ComponentStore + + +class TestComponentStore(unittest.TestCase): + def setUp(self): + self.cs = ComponentStore("TestStore", "I am a test store, please ignore") + + def test_empty(self): + self.assertEqual(len(self.cs), 0) + self.assertEqual(list(self.cs), []) + + def test_add(self): + test_obj = object() + + self.assertFalse("test_obj" in self.cs) + + self.cs.add("test_obj", "Test object", test_obj) + + self.assertTrue("test_obj" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_obj", test_obj)]) + + self.assertEqual(self.cs.test_obj, test_obj) + self.assertEqual(self.cs["test_obj"], test_obj) + + def test_add2(self): + test_obj1 = object() + test_obj2 = object() + + self.cs.add("test_obj1", "Test object", test_obj1) + self.cs.add("test_obj2", "Test object", test_obj2) + + self.assertEqual(len(self.cs), 2) + self.assertTrue("test_obj1" in self.cs) + self.assertTrue("test_obj2" in self.cs) + + def test_add_def(self): + self.assertFalse("test_func" in self.cs) + + @self.cs.add_def("test_func", "Test function") + def test_func(): + return 123 + + self.assertTrue("test_func" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_func", test_func)]) + + self.assertEqual(self.cs.test_func, test_func) + self.assertEqual(self.cs["test_func"], test_func) + + # try adding the same function again + self.cs.add_def("test_func", "Test function but with new description")(test_func) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(self.cs.test_func, test_func) diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 00000000..e0625321 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,47 @@ +# Copyright 
(c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from generative.utils import unsqueeze_left, unsqueeze_right + +RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] + +LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] + +ALL_CASES = [ + (np.random.rand(3, 4), 2, (3, 4)), + (np.random.rand(3, 4), 0, (3, 4)), + (np.random.rand(3, 4), -1, (3, 4)), + (np.array(3), 4, (1, 1, 1, 1)), + (np.array(3), 0, ()), + (torch.rand(3, 4), 2, (3, 4)), + (torch.rand(3, 4), 0, (3, 4)), + (torch.rand(3, 4), -1, (3, 4)), + (torch.tensor(3), 4, (1, 1, 1, 1)), + (torch.tensor(3), 0, ()), +] + + +class TestUnsqueeze(unittest.TestCase): + @parameterized.expand(RIGHT_CASES + ALL_CASES) + def test_unsqueeze_right(self, arr, ndim, shape): + self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) + + @parameterized.expand(LEFT_CASES + ALL_CASES) + def test_unsqueeze_left(self, arr, ndim, shape): + self.assertEqual(unsqueeze_left(arr, ndim).shape, shape) From ea6d732510aaa4da7c93f97cbfa407c3a38b2ab9 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 17 Mar 2023 18:02:04 +0000 Subject: [PATCH 04/14] Updates from comments --- generative/networks/schedulers/ddpm.py | 23 +++++----- generative/networks/schedulers/scheduler.py | 51 +++++++++++---------- generative/utils/component_store.py | 9 +++- tests/test_scheduler_ddpm.py | 8 ++-- 4 files changed, 49 insertions(+), 42 deletions(-) diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index f60f0ae2..1e50f447 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -37,7 +37,7 @@ from monai.utils import StrEnum -from .scheduler import Scheduler, BetaSchedules +from .scheduler import Scheduler, NoiseSchedules from generative.utils import unsqueeze_right @@ -48,8 +48,8 @@ class DDPMVarianceType(StrEnum): LEARNED_RANGE = "learned_range" -class DDPMPRedictionType(StrEnum): - EPSILON = "epsiolon" +class DDPMPredictionType(StrEnum): + EPSILON = "epsilon" SAMPLE = "sample" V_PREDICTION = "v_prediction" @@ -78,19 +78,18 @@ class DDPMScheduler(Scheduler): def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 1e-4, - beta_end: float = 2e-2, - beta_schedule: str = "linear", + schedule: str = "linear_beta", variance_type: str = DDPMVarianceType.FIXED_SMALL, clip_sample: bool = True, - prediction_type: str = DDPMPRedictionType.EPSILON, + prediction_type: str = DDPMPredictionType.EPSILON, + **schedule_args ) -> None: - super().__init__(num_train_timesteps, beta_start, beta_end, beta_schedule, prediction_type) + super().__init__(num_train_timesteps, schedule,**schedule_args) if variance_type not in DDPMVarianceType.__members__.values(): raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") - if prediction_type not in 
DDPMPRedictionType.__members__.values(): + if prediction_type not in DDPMPredictionType.__members__.values(): raise ValueError("Argument `prediction_type` must be a member of `DDPMPRedictionType`") self.clip_sample = clip_sample @@ -206,11 +205,11 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.prediction_type == DDPMPRedictionType.EPSILON: + if self.prediction_type == DDPMPredictionType.EPSILON: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.prediction_type == DDPMPRedictionType.SAMPLE: + elif self.prediction_type == DDPMPredictionType.SAMPLE: pred_original_sample = model_output - elif self.prediction_type == DDPMPRedictionType.V_PREDICTION: + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index f343dd4c..610aa53e 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -39,27 +39,29 @@ from generative.utils import ComponentStore, unsqueeze_right -BetaSchedules = ComponentStore("BetaSchedules", "Functions to generate beta schedules given start/end values and steps") +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") -@BetaSchedules.add_def("linear", "Linear beta schedule") -def _linear(beta_start, beta_end, num_train_timesteps): +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule, args: num_train_timesteps, beta_start, beta_end") +def _linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) -@BetaSchedules.add_def("scaled_linear", "Scaled linear beta schedule") -def _scaled_linear(beta_start, beta_end, num_train_timesteps): +@NoiseSchedules.add_def( + "scaled_linear_beta", "Scaled linear beta schedule, args: num_train_timesteps, beta_start, beta_end" +) +def _scaled_linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 -@BetaSchedules.add_def("sigmoid", "Sigmoid beta schedule") -def _sigmoid(beta_start, beta_end, num_train_timesteps, sig_range=6): +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule, args: num_train_timesteps, beta_start, beta_end") +def _sigmoid_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2, sig_range=6): betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start -@BetaSchedules.add_def("cosine", "Cosine beta schedule") -def _cosine(beta_start, beta_end, num_train_timesteps, s=0.008): +@NoiseSchedules.add_def("cosine_beta", "Cosine beta schedule, args: num_train_timesteps") +def _cosine_beta(num_train_timesteps, s=0.008): x = torch.linspace(0, num_train_timesteps, num_train_timesteps - 1) alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] @@ -73,31 +75,30 @@ class Scheduler(nn.Module): Args: num_train_timesteps: number of diffusion steps used to train the model. - beta_start: the starting `beta` value of inference. - beta_end: the final `beta` value. 
- beta_schedule: member of BetaSchedules - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function """ - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 1e-4, - beta_end: float = 2e-2, - beta_schedule: str = "linear", - prediction_type: str = "epsilon", - ) -> None: + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear", **schedule_args) -> None: super().__init__() - self.betas = BetaSchedules[beta_schedule](beta_start, beta_end, num_train_timesteps) + schedule_args["num_train_timesteps"] = num_train_timesteps + noise = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise + else: + self.betas = noise + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.num_train_timesteps = num_train_timesteps - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) # settable values self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: """ diff --git a/generative/utils/component_store.py b/generative/utils/component_store.py index 765142e2..a088f6a2 100644 --- a/generative/utils/component_store.py +++ b/generative/utils/component_store.py @@ -35,7 +35,7 @@ def __init__(self, name: str, description: str) -> None: self.name: str = name self.description: str = description - self.__doc__ = f"Component Store '{name}': {description}\n\n{self.__doc__ or ''}".strip() + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() def add(self, name: str, desc: str, value: T) -> T: """Store the object `value` under the name `name` with description `desc`.""" @@ -67,6 +67,13 @@ def __iter__(self) -> Iterable: for k, v in self.components.items(): yield k, v.value + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n {k}: {v.description}" + + return result + def __getattr__(self, name: str) -> Any: """Returns the stored object under the given name.""" if name in self.components: diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py index 01f0b7b4..835537fe 100644 --- a/tests/test_scheduler_ddpm.py +++ b/tests/test_scheduler_ddpm.py @@ -19,17 +19,17 @@ from generative.networks.schedulers import DDPMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: for variance_type in ["fixed_small", "fixed_large"]: TEST_2D_CASE.append( - [{"beta_schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] ) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: for variance_type in ["fixed_small", "fixed_large"]: 
TEST_3D_CASE.append( - [{"beta_schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] + [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] ) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE From 78bb9facfaa00db46905023d7d43427593f57201 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 30 Mar 2023 18:33:37 +0100 Subject: [PATCH 05/14] Updates to other schedulers Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/ddim.py | 112 ++++++-------------- generative/networks/schedulers/ddpm.py | 3 +- generative/networks/schedulers/pndm.py | 70 +++--------- generative/networks/schedulers/scheduler.py | 73 ++++++++++--- tests/test_scheduler_ddim.py | 8 +- tests/test_scheduler_pndm.py | 8 +- 6 files changed, 114 insertions(+), 160 deletions(-) diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 3f4ac0c6..ec819236 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -35,8 +35,18 @@ import torch import torch.nn as nn +from monai.utils import StrEnum -class DDIMScheduler(nn.Module): +from .scheduler import Scheduler + + +class DDIMPredictionType(StrEnum): + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDIMScheduler(Scheduler): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion @@ -64,36 +74,20 @@ class DDIMScheduler(nn.Module): def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 1e-4, - beta_end: float = 2e-2, - beta_schedule: str = "linear", + schedule: str = "linear_beta", clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: str = DDIMPredictionType.EPSILON, + **schedule_args ) -> None: - super().__init__() - self.beta_schedule = beta_schedule - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + super().__init__(num_train_timesteps,schedule,**schedule_args) - if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]: - raise ValueError( - f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" - ) + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError(f"Argument `prediction_type` must be a member of DDIMPredictionType") self.prediction_type = prediction_type - self.num_train_timesteps = num_train_timesteps - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - + # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or @@ -103,13 +97,13 @@ def __init__( # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)) + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) self.clip_sample = clip_sample self.steps_offset = steps_offset # default the number of inference timesteps to the number of train steps - self.set_timesteps(num_train_timesteps) + self.set_timesteps(self.num_train_timesteps) def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: """ @@ -190,13 +184,13 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) pred_epsilon = model_output - elif self.prediction_type == "sample": + elif self.prediction_type == DDIMPredictionType.SAMPLE: pred_original_sample = model_output - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.prediction_type == "v_prediction": + pred_epsilon = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample @@ -207,67 +201,21 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) + std_dev_t = eta * variance ** 0.5 # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon # 7. 
compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + pred_prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device = model_output.device if torch.is_tensor(model_output) else "cpu" noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) - variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise pred_prev_sample = pred_prev_sample + variance return pred_prev_sample, pred_original_sample - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. - - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity + \ No newline at end of file diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 1e50f447..ed025103 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -37,8 +37,7 @@ from monai.utils import StrEnum -from .scheduler import Scheduler, NoiseSchedules -from generative.utils import unsqueeze_right +from .scheduler import Scheduler class DDPMVarianceType(StrEnum): diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 
0a1f2018..83cb323c 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -37,8 +37,17 @@ import torch import torch.nn as nn +from monai.utils import StrEnum -class PNDMScheduler(nn.Module): +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., @@ -70,33 +79,19 @@ class PNDMScheduler(nn.Module): def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 1e-4, - beta_end: float = 2e-2, - beta_schedule: str = "linear", + schedule: str = "linear_beta", skip_prk_steps: bool = False, set_alpha_to_one: bool = False, - prediction_type: str = "epsilon", + prediction_type: str = PNDMPredictionType.EPSILON, steps_offset: int = 0, + **schedule_args ) -> None: - super().__init__() - self.beta_schedule = beta_schedule - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + super().__init__(num_train_timesteps,schedule,**schedule_args) - if prediction_type.lower() not in ["epsilon", "v_prediction"]: - raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon` or `v_prediction`") + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError(f"Argument `prediction_type` must be a member of PNDMPredictionType") self.prediction_type = prediction_type - self.num_train_timesteps = num_train_timesteps - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] @@ -117,8 +112,6 @@ def __init__( self.cur_sample = None self.ets = [] - self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - # default the number of inference timesteps to the number of train steps self.set_timesteps(num_train_timesteps) @@ -302,7 +295,7 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev - if self.prediction_type == "v_prediction": + if self.prediction_type == PNDMPredictionType.V_PREDICTION: model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample # corresponds to (α_(t−δ) - α_t) divided by @@ -322,32 +315,3 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i ) return prev_sample - - def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - """ - Add noise to the original samples. - - Args: - original_samples: original samples - noise: noise to add to samples - timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
- - Returns: - noisy_samples: sample with added noise - """ - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten() - while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape): - sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index 610aa53e..624424be 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -42,31 +42,74 @@ NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") -@NoiseSchedules.add_def("linear_beta", "Linear beta schedule, args: num_train_timesteps, beta_start, beta_end") +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") def _linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) -@NoiseSchedules.add_def( - "scaled_linear_beta", "Scaled linear beta schedule, args: num_train_timesteps, beta_start, beta_end" -) +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") def _scaled_linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 -@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule, args: num_train_timesteps, beta_start, beta_end") +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") def _sigmoid_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2, sig_range=6): + """ + Sigmoid beta noise schedule function. 
+ + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start -@NoiseSchedules.add_def("cosine_beta", "Cosine beta schedule, args: num_train_timesteps") -def _cosine_beta(num_train_timesteps, s=0.008): - x = torch.linspace(0, num_train_timesteps, num_train_timesteps - 1) +@NoiseSchedules.add_def("cosine_beta", "Cosine beta schedule") +def _cosine_beta(num_train_timesteps, s=8e-3): + """ + Cosine noise schedule returning. + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0.0001, 0.9999) + alphas_cumprod /= alphas_cumprod[0].item() + alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) + betas = 1.0 - alphas + return betas, alphas, alphas_cumprod[:-1] class Scheduler(nn.Module): @@ -83,13 +126,13 @@ class Scheduler(nn.Module): def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear", **schedule_args) -> None: super().__init__() schedule_args["num_train_timesteps"] = num_train_timesteps - noise = NoiseSchedules[schedule](**schedule_args) + noise_sched = NoiseSchedules[schedule](**schedule_args) # set betas, alphas, alphas_cumprod based off return value from noise function - if isinstance(noise, tuple): - self.betas, self.alphas, self.alphas_cumprod = noise + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched else: - self.betas = noise + self.betas = noise_sched self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py index 67d773fe..3c64b42c 100644 --- a/tests/test_scheduler_ddim.py +++ b/tests/test_scheduler_ddim.py @@ -19,12 +19,12 @@ from generative.networks.schedulers import DDIMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_2D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_3D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py index ee0cda29..4e0dbb97 100644 --- a/tests/test_scheduler_pndm.py +++ b/tests/test_scheduler_pndm.py @@ -19,12 +19,12 @@ from generative.networks.schedulers import PNDMScheduler TEST_2D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_2D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) +for beta_schedule in ["linear_beta", 
"scaled_linear_beta"]: + TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) TEST_3D_CASE = [] -for beta_schedule in ["linear", "scaled_linear"]: - TEST_3D_CASE.append([{"beta_schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) +for beta_schedule in ["linear_beta", "scaled_linear_beta"]: + TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) TEST_CASES = TEST_2D_CASE + TEST_3D_CASE From 3b4e35fc835332aa9fa6987d88fda8e9a9624e27 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 4 Apr 2023 13:21:17 +0100 Subject: [PATCH 06/14] Update Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/ddim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index da250d9c..fa2c6268 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -257,13 +257,13 @@ def reversed_step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.prediction_type == "epsilon": + if self.prediction_type == DDIMPredictionType.EPSILON: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output - elif self.prediction_type == "sample": + elif self.prediction_type == DDIMPredictionType.SAMPLE: pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - elif self.prediction_type == "v_prediction": + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample From c710306a2dc1e5243fe2b7af8f090b0a1d586f18 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 4 Apr 2023 13:38:05 +0100 Subject: [PATCH 07/14] Tutorials updates Signed-off-by: Eric Kerfoot --- tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb | 2 +- tutorials/generative/2d_ldm/2d_ldm_tutorial.py | 2 +- .../2d_stable_diffusion_v2_super_resolution.ipynb | 4 ++-- .../2d_stable_diffusion_v2_super_resolution.py | 4 ++-- tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb | 4 ++-- tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py | 4 ++-- tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb | 2 +- tutorials/generative/3d_ldm/3d_ldm_tutorial.py | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb index cb4bd4b4..9a09dc95 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb @@ -851,7 +851,7 @@ " num_head_channels=(0, 256, 512),\n", ")\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py index 9face129..681c0a1e 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py @@ -310,7 +310,7 @@ num_head_channels=(0, 256, 512), ) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", 
beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) # - # ### Scaling factor diff --git a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb index bad152f6..c6e58254 100644 --- a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb +++ b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb @@ -881,7 +881,7 @@ ")\n", "unet = unet.to(device)\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { @@ -899,7 +899,7 @@ "metadata": {}, "outputs": [], "source": [ - "low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"linear\", beta_start=0.0015, beta_end=0.0195)\n", + "low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)\n", "\n", "max_noise_level = 350" ] diff --git a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py index 1234935d..82369fd7 100644 --- a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py +++ b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.py @@ -323,13 +323,13 @@ ) unet = unet.to(device) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) # %% [markdown] # As mentioned, we will use the conditioned augmentation (introduced in [2] section 3 and used on Stable Diffusion Upscalers and Imagen Video [3] Section 2.5) as it has been shown critical for cascaded diffusion models, as well for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We will use a scheduler `low_res_scheduler` to add this noise, with the `t` step defining the signal-to-noise ratio and use the `t` value to condition the diffusion model (inputted using `class_labels` argument). 
# %% -low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195) +low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195) max_noise_level = 350 diff --git a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb index e5d0f2fb..1174e567 100644 --- a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb +++ b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.ipynb @@ -357,7 +357,7 @@ "metadata": {}, "outputs": [], "source": [ - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0005, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0005, beta_end=0.0195)" ] }, { @@ -901,7 +901,7 @@ ], "source": [ "scheduler_ddim = DDIMScheduler(\n", - " num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0005, beta_end=0.0195, clip_sample=False\n", + " num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0005, beta_end=0.0195, clip_sample=False\n", ")\n", "\n", "scheduler_ddim.set_timesteps(num_inference_steps=250)\n", diff --git a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py index 7ea85756..527b96d6 100644 --- a/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py +++ b/tutorials/generative/3d_ddpm/3d_ddpm_tutorial.py @@ -164,7 +164,7 @@ # Together with our U-net, we need to define the Noise Scheduler for the diffusion model. This scheduler is responsible for defining the amount of noise that should be added in each timestep `t` of the diffusion model's Markov chain. Besides that, it has the operations to perform the reverse process, which will remove the noise of the images (a.k.a. denoising process). In this case, we are using a `DDPMScheduler`. Here we are using 1000 timesteps and a `scaled_linear` profile for the beta values (proposed in [Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models"](https://arxiv.org/abs/2112.10752)). This profile had better results than the `linear, proposed in the original DDPM's paper. In `beta_start` and `beta_end`, we define the limits for the beta values. These are important to determine how accentuated is the addition of noise in the image. 
# %% -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0005, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195) # %% plt.plot(scheduler.alphas_cumprod.cpu(), color=(2 / 255, 163 / 255, 163 / 255), linewidth=2) @@ -310,7 +310,7 @@ # %% scheduler_ddim = DDIMScheduler( - num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0005, beta_end=0.0195, clip_sample=False + num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195, clip_sample=False ) scheduler_ddim.set_timesteps(num_inference_steps=250) diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb index 48e96ffe..5e07974f 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb @@ -741,7 +741,7 @@ "unet.to(device)\n", "\n", "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)" + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"scaled_linear_beta\", beta_start=0.0015, beta_end=0.0195)" ] }, { diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py index 0cf6a302..6ea8cfb0 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py @@ -308,7 +308,7 @@ def KL_loss(z_mu, z_sigma): unet.to(device) -scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0015, beta_end=0.0195) +scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0015, beta_end=0.0195) # - # ### Scaling factor From c736c8ab56b871d329f76c5a17a8913024776268 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 4 Apr 2023 13:57:11 +0100 Subject: [PATCH 08/14] Update Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/__init__.py | 1 + generative/networks/schedulers/ddim.py | 2 +- generative/networks/schedulers/ddpm.py | 2 +- generative/networks/schedulers/pndm.py | 2 +- generative/networks/schedulers/scheduler.py | 14 +++++++++++++- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/generative/networks/schedulers/__init__.py b/generative/networks/schedulers/__init__.py index bb2eb347..4e19e5ef 100644 --- a/generative/networks/schedulers/__init__.py +++ b/generative/networks/schedulers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .scheduler import Scheduler from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index fa2c6268..48fc7629 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -37,7 +37,7 @@ from monai.utils import StrEnum -from .scheduler import Scheduler +from generative.networks.schedulers import Scheduler class DDIMPredictionType(StrEnum): diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index ed025103..0216ce85 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -37,7 +37,7 @@ from monai.utils import StrEnum -from .scheduler import Scheduler +from generative.networks.schedulers import Scheduler class DDPMVarianceType(StrEnum): diff --git 
a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 83cb323c..99891937 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -39,7 +39,7 @@ from monai.utils import StrEnum -from .scheduler import Scheduler +from generative.networks.schedulers import Scheduler class PNDMPredictionType(StrEnum): diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index 624424be..23cdff47 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -114,7 +114,19 @@ def _cosine_beta(num_train_timesteps, s=8e-3): class Scheduler(nn.Module): """ - Base class for other schedulers based on a beta noise schedule. + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + @NoiseSchedules.add_def("linear_beta", "Linear beta schedule") + def _linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) Args: num_train_timesteps: number of diffusion steps used to train the model. From 43c8da73c24d103553560e5ac4282d236ad455ec Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 4 Apr 2023 14:16:49 +0100 Subject: [PATCH 09/14] Update Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/scheduler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index 23cdff47..7f3f05e1 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -123,11 +123,14 @@ class Scheduler(nn.Module): can be provided by using the NoiseSchedules.add_def, for example: .. code-block:: python + from generative.networks.schedulers import NoiseSchedules, DDPMScheduler - @NoiseSchedules.add_def("linear_beta", "Linear beta schedule") - def _linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + scheduler = DDPMScheduler(1000, "my_beta_schedule") + Args: num_train_timesteps: number of diffusion steps used to train the model. 
schedule: member of NoiseSchedules, From a79b75d79388059971e44203b4b531be1c1ba60f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 5 Apr 2023 17:38:59 +0100 Subject: [PATCH 10/14] Fixes Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/__init__.py | 4 +-- generative/networks/schedulers/ddim.py | 16 +++++------- generative/networks/schedulers/ddpm.py | 5 ++-- generative/networks/schedulers/pndm.py | 5 ++-- generative/networks/schedulers/scheduler.py | 9 ++++--- generative/utils/component_store.py | 29 ++++++++++++++++++++- 6 files changed, 46 insertions(+), 22 deletions(-) diff --git a/generative/networks/schedulers/__init__.py b/generative/networks/schedulers/__init__.py index 4e19e5ef..359576cd 100644 --- a/generative/networks/schedulers/__init__.py +++ b/generative/networks/schedulers/__init__.py @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - -from .scheduler import Scheduler +from .scheduler import Scheduler, NoiseSchedules from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 48fc7629..1ea8de79 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -33,7 +33,6 @@ import numpy as np import torch -import torch.nn as nn from monai.utils import StrEnum @@ -79,15 +78,15 @@ def __init__( set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = DDIMPredictionType.EPSILON, - **schedule_args + **schedule_args, ) -> None: - super().__init__(num_train_timesteps,schedule,**schedule_args) + super().__init__(num_train_timesteps, schedule, **schedule_args) if prediction_type not in DDIMPredictionType.__members__.values(): raise ValueError(f"Argument `prediction_type` must be a member of DDIMPredictionType") self.prediction_type = prediction_type - + # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or @@ -185,11 +184,11 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.prediction_type == DDIMPredictionType.EPSILON: - pred_original_sample = (sample - (beta_prod_t ** 0.5) * model_output) / (alpha_prod_t ** 0.5) + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) pred_epsilon = model_output elif self.prediction_type == DDIMPredictionType.SAMPLE: pred_original_sample = model_output - pred_epsilon = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5) + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) elif self.prediction_type == DDIMPredictionType.V_PREDICTION: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample @@ -201,13 +200,13 @@ def step( # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** 0.5 + std_dev_t = eta * variance**0.5 # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 @@ -278,4 +277,3 @@ def reversed_step( pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction return pred_post_sample, pred_original_sample - diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 0216ce85..e91cdaa7 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -33,7 +33,6 @@ import numpy as np import torch -import torch.nn as nn from monai.utils import StrEnum @@ -81,9 +80,9 @@ def __init__( variance_type: str = DDPMVarianceType.FIXED_SMALL, clip_sample: bool = True, prediction_type: str = DDPMPredictionType.EPSILON, - **schedule_args + **schedule_args, ) -> None: - super().__init__(num_train_timesteps, schedule,**schedule_args) + super().__init__(num_train_timesteps, schedule, **schedule_args) if variance_type not in DDPMVarianceType.__members__.values(): raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 99891937..32116368 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -35,7 +35,6 @@ import numpy as np import torch -import torch.nn as nn from monai.utils import StrEnum @@ -84,9 +83,9 @@ def __init__( set_alpha_to_one: bool = False, prediction_type: str = PNDMPredictionType.EPSILON, steps_offset: int = 0, - **schedule_args + **schedule_args, ) -> None: - super().__init__(num_train_timesteps,schedule,**schedule_args) + super().__init__(num_train_timesteps, schedule, **schedule_args) if prediction_type not in PNDMPredictionType.__members__.values(): raise ValueError(f"Argument `prediction_type` must be a member of PNDMPredictionType") diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index 7f3f05e1..b68f4629 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -95,7 +95,7 @@ def _sigmoid_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2, sig_range @NoiseSchedules.add_def("cosine_beta", "Cosine beta schedule") def _cosine_beta(num_train_timesteps, s=8e-3): """ - Cosine noise schedule returning. + Cosine noise schedule. Args: num_train_timesteps: number of timesteps @@ -130,7 +130,10 @@ def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) scheduler = DDPMScheduler(1000, "my_beta_schedule") - + + To see what noise functions are available, print the object NoiseSchedules to get a listing of stored objects + with their docstring descriptions. + Args: num_train_timesteps: number of diffusion steps used to train the model. 
         schedule: member of NoiseSchedules,
@@ -138,7 +141,7 @@ def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
         schedule_args: arguments to pass to the schedule function
     """
 
-    def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear", **schedule_args) -> None:
+    def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None:
         super().__init__()
         schedule_args["num_train_timesteps"] = num_train_timesteps
         noise_sched = NoiseSchedules[schedule](**schedule_args)
diff --git a/generative/utils/component_store.py b/generative/utils/component_store.py
index a088f6a2..6be1744f 100644
--- a/generative/utils/component_store.py
+++ b/generative/utils/component_store.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from __future__ import annotations
+from textwrap import dedent, indent
 from collections import namedtuple
 from keyword import iskeyword
 from typing import TypeVar, Callable, Any, Dict, Iterable
@@ -25,7 +26,27 @@ def is_variable(name):
 class ComponentStore:
     """
     Represents a storage object for other objects (specifically functions) keyed to a name with a description.
+    These objects act as global named places for storing components for objects parameterised by component names.
+    Typically these are functions, although other objects can be added. Printing a component store will produce a
+    list of members along with their docstring information if present.
+
+    Example:
+
+    .. code-block:: python
+
+        TestStore = ComponentStore("Test Store", "A test store for demo purposes")
+
+        @TestStore.add_def("my_func_name", "Some description of your function")
+        def _my_func(a, b):
+            '''A description of your function here.'''
+            return a * b
+
+        print(TestStore)  # will print out name, description, and 'my_func_name' with the docstring
+
+        func = TestStore["my_func_name"]
+        result = func(7, 6)
+
     """
 
     _Component = namedtuple("Component", ("description", "value"))  # internal value pair
@@ -70,7 +91,13 @@ def __iter__(self) -> Iterable:
     def __str__(self):
         result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
         for k, v in self.components.items():
-            result += f"\n  {k}: {v.description}"
+            result += f"\n* {k}:"
+
+            if hasattr(v.value, "__doc__"):
+                doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), "    ")
+                result += f"\n{doc}\n"
+            else:
+                result += f"  {v.description}"
 
         return result

From c861983b61aa2fd8edb13f96f0dea5ab77677768 Mon Sep 17 00:00:00 2001
From: Eric Kerfoot
Date: Wed, 5 Apr 2023 18:50:33 +0100
Subject: [PATCH 11/14] Updates from comments

Signed-off-by: Eric Kerfoot
---
 generative/networks/schedulers/ddim.py      | 19 +++++++++------
 generative/networks/schedulers/ddpm.py      | 25 +++++++++++--------
 generative/networks/schedulers/pndm.py      | 17 +++++++------
 generative/networks/schedulers/scheduler.py | 27 +++++++++++++--------
 4 files changed, 52 insertions(+), 36 deletions(-)

diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py
index 1ea8de79..f1d321e5 100644
--- a/generative/networks/schedulers/ddim.py
+++ b/generative/networks/schedulers/ddim.py
@@ -40,6 +40,13 @@ class DDIMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.
+
+    epsilon: predicting the noise of the diffusion process
+    sample: directly predicting the noisy sample
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
     EPSILON = "epsilon"
     SAMPLE = "sample"
     V_PREDICTION = "v_prediction"
@@ -53,10 +60,7 @@ class DDIMScheduler(Scheduler):
 
     Args:
         num_train_timesteps: number of diffusion steps used to train the model.
-        beta_start: the starting `beta` value of inference.
-        beta_end: the final `beta` value.
-        beta_schedule: {``"linear"``, ``"scaled_linear"``}
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
             For the final step there is no previous alpha. When this option is `True` the previous alpha product is
@@ -64,10 +68,9 @@
         steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
-        prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of DDIMPredictionType
+        schedule_args: arguments to pass to the schedule function
+
     """

     def __init__(
diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py
index e91cdaa7..6a1c0be0 100644
--- a/generative/networks/schedulers/ddpm.py
+++ b/generative/networks/schedulers/ddpm.py
@@ -40,6 +40,10 @@ class DDPMVarianceType(StrEnum):
+    """
+    Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
+    to the denoised sample.
+    """
     FIXED_SMALL = "fixed_small"
     FIXED_LARGE = "fixed_large"
     LEARNED = "learned"
@@ -47,6 +51,13 @@ class DDPMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.
+
+    epsilon: predicting the noise of the diffusion process
+    sample: directly predicting the noisy sample
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
     EPSILON = "epsilon"
     SAMPLE = "sample"
     V_PREDICTION = "v_prediction"
@@ -60,17 +71,11 @@ class DDPMScheduler(Scheduler):
 
     Args:
         num_train_timesteps: number of diffusion steps used to train the model.
-        beta_start: the starting `beta` value of inference.
-        beta_end: the final `beta` value.
-        beta_schedule: member of BetaSchedules
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
-        variance_type: member of VarianceType
-            options to clip the variance used when adding noise to the denoised sample.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
+        variance_type: member of DDPMVarianceType
         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
-        prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of DDPMPredictionType
+        schedule_args: arguments to pass to the schedule function
     """

     def __init__(
diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py
index 32116368..cb59dd5b 100644
--- a/generative/networks/schedulers/pndm.py
+++ b/generative/networks/schedulers/pndm.py
@@ -42,6 +42,12 @@ class PNDMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
+
+    epsilon: predicting the noise of the diffusion process
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
     EPSILON = "epsilon"
     V_PREDICTION = "v_prediction"
@@ -54,10 +60,7 @@ class PNDMScheduler(Scheduler):
 
     Args:
         num_train_timesteps: number of diffusion steps used to train the model.
-        beta_start: the starting `beta` value of inference.
-        beta_end: the final `beta` value.
-        beta_schedule: {``"linear"``, ``"scaled_linear"``}
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
         skip_prk_steps:
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms step.
@@ -65,14 +68,12 @@
             each diffusion step uses the value of alphas product at that step and at the previous one. For the final
             step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the value of alpha at step 0.
-        prediction_type: {``"epsilon"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of PNDMPredictionType
         steps_offset:
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        schedule_args: arguments to pass to the schedule function
     """

     def __init__(
diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py
index b68f4629..0bc3fe9e 100644
--- a/generative/networks/schedulers/scheduler.py
+++ b/generative/networks/schedulers/scheduler.py
@@ -43,7 +43,7 @@
 
 @NoiseSchedules.add_def("linear_beta", "Linear beta schedule")
-def _linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
     """
     Linear beta noise schedule function.

@@ -59,7 +59,7 @@
 
 @NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule")
-def _scaled_linear_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
     """
     Scaled linear beta noise schedule function.
@@ -75,7 +75,7 @@
 
 @NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule")
-def _sigmoid_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2, sig_range=6):
+def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6):
     """
     Sigmoid beta noise schedule function.

@@ -92,14 +92,14 @@ def _sigmoid_beta(num_train_timesteps, beta_start=1e-4, beta_end=2e-2, sig_range
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
 
 
-@NoiseSchedules.add_def("cosine_beta", "Cosine beta schedule")
-def _cosine_beta(num_train_timesteps, s=8e-3):
+@NoiseSchedules.add_def("cosine", "Cosine schedule")
+def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
     """
-    Cosine noise schedule.
+    Cosine noise schedule, see https://arxiv.org/abs/2102.09672
 
     Args:
         num_train_timesteps: number of timesteps
-        s: smoothing factor, default 8e-3
+        s: smoothing factor, default 8e-3 (see referenced paper)
 
     Returns:
         (betas, alphas, alpha_cumprod) values
@@ -129,10 +129,17 @@ class Scheduler(nn.Module):
         def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
             return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
 
-        scheduler = DDPMScheduler(1000, "my_beta_schedule")
+        scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")
 
-    To see what noise functions are available, print the object NoiseSchedules to get a listing of stored objects
-    with their docstring descriptions.
+    All such functions should have an initial positional integer argument `num_train_timesteps` stating the number
+    of timesteps the schedule is for; any further arguments are passed by keyword through the constructor's
+    `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules to get a
+    listing of stored objects with their docstring descriptions.
+
+    Note: in previous versions of the schedulers the argument `beta_schedule` was used to state the beta schedule
+    type, this is now replaced with `schedule`, and most names used with the previous argument now have "_beta"
+    appended to them, e.g. 'beta_schedule="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end`
+    arguments are still used for some schedules but these are provided as keyword arguments now.
 
     Args:
         num_train_timesteps: number of diffusion steps used to train the model.
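
The base class contract described above means a registered schedule function may return either a beta tensor alone
or the full (betas, alphas, alphas_cumprod) triple, which `Scheduler.__init__` unpacks directly rather than deriving
the alpha values from the betas. A minimal sketch of the triple-returning form, assuming the package layout these
patches create; the name "demo_triple" and its body are illustrative only, not part of the patch series:

.. code-block:: python

    import torch

    from generative.networks.schedulers import DDPMScheduler, NoiseSchedules


    @NoiseSchedules.add_def("demo_triple", "Example schedule returning a (betas, alphas, alphas_cumprod) triple")
    def _demo_triple(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
        """Linear beta schedule that also precomputes alphas and their cumulative product."""
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return betas, alphas, alphas_cumprod


    # beta_start/beta_end are forwarded to the schedule function through **schedule_args
    scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="demo_triple", beta_start=1e-4, beta_end=2e-2)
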
From 73d21d9b489e57ab696cdbfdf58fc809523a8592 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Tue, 16 May 2023 14:13:39 +0100 Subject: [PATCH 12/14] Update generative/networks/schedulers/ddpm.py Co-authored-by: Mark Graham Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- generative/networks/schedulers/ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 6a1c0be0..944e534a 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -93,7 +93,7 @@ def __init__( raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") if prediction_type not in DDPMPredictionType.__members__.values(): - raise ValueError("Argument `prediction_type` must be a member of `DDPMPRedictionType`") + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") self.clip_sample = clip_sample self.variance_type = variance_type From 71b3c55d36f68ffe496629d5b653468711d2cfb9 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 16 May 2023 14:21:15 +0100 Subject: [PATCH 13/14] Autofixin' Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/__init__.py | 4 +++- generative/networks/schedulers/ddim.py | 6 +++--- generative/networks/schedulers/ddpm.py | 7 ++++--- generative/networks/schedulers/pndm.py | 4 ++-- generative/networks/schedulers/scheduler.py | 3 ++- generative/utils/__init__.py | 2 +- generative/utils/component_store.py | 5 +++-- generative/utils/misc.py | 1 + ...2d_classifierfree_guidance_anomalydetection_tutorial.py | 1 + .../anomaly_detection_with_transformers.py | 2 -- .../tutorial_segmentation_with_ddpm.py | 1 - 11 files changed, 20 insertions(+), 16 deletions(-) diff --git a/generative/networks/schedulers/__init__.py b/generative/networks/schedulers/__init__.py index 359576cd..29e9020d 100644 --- a/generative/networks/schedulers/__init__.py +++ b/generative/networks/schedulers/__init__.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .scheduler import Scheduler, NoiseSchedules +from __future__ import annotations + from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler +from .scheduler import NoiseSchedules, Scheduler diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index f1d321e5..10cb5031 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -33,7 +33,6 @@ import numpy as np import torch - from monai.utils import StrEnum from generative.networks.schedulers import Scheduler @@ -42,11 +41,12 @@ class DDIMPredictionType(StrEnum): """ Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument. - + epsilon: predicting the noise of the diffusion process sample: directly predicting the noisy sample v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf """ + EPSILON = "epsilon" SAMPLE = "sample" V_PREDICTION = "v_prediction" @@ -70,7 +70,7 @@ class DDIMScheduler(Scheduler): stable diffusion. 
        prediction_type: member of DDIMPredictionType
         schedule_args: arguments to pass to the schedule function
-
+
     """

     def __init__(
diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py
index 944e534a..17d2112c 100644
--- a/generative/networks/schedulers/ddpm.py
+++ b/generative/networks/schedulers/ddpm.py
@@ -33,7 +33,6 @@
 
 import numpy as np
 import torch
-
 from monai.utils import StrEnum
 
 from generative.networks.schedulers import Scheduler
@@ -41,9 +40,10 @@ class DDPMVarianceType(StrEnum):
     """
-    Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
+    Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
     to the denoised sample.
     """
+
     FIXED_SMALL = "fixed_small"
     FIXED_LARGE = "fixed_large"
     LEARNED = "learned"
@@ -53,11 +53,12 @@ class DDPMPredictionType(StrEnum):
     """
     Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.
-
+
     epsilon: predicting the noise of the diffusion process
     sample: directly predicting the noisy sample
     v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
     """
+
     EPSILON = "epsilon"
     SAMPLE = "sample"
     V_PREDICTION = "v_prediction"
diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py
index cb59dd5b..5776e5e5 100644
--- a/generative/networks/schedulers/pndm.py
+++ b/generative/networks/schedulers/pndm.py
@@ -35,7 +35,6 @@
 
 import numpy as np
 import torch
-
 from monai.utils import StrEnum
 
 from generative.networks.schedulers import Scheduler
@@ -44,10 +43,11 @@ class PNDMPredictionType(StrEnum):
     """
     Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
-
+
     epsilon: predicting the noise of the diffusion process
     v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
     """
+
     EPSILON = "epsilon"
     V_PREDICTION = "v_prediction"
diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py
index 0bc3fe9e..60c957ff 100644
--- a/generative/networks/schedulers/scheduler.py
+++ b/generative/networks/schedulers/scheduler.py
@@ -31,14 +31,15 @@
 
 from __future__ import annotations
+
 from typing import Callable
+
 import numpy as np
 import torch
 import torch.nn as nn
 
 from generative.utils import ComponentStore, unsqueeze_right
 
-
 NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules")
diff --git a/generative/utils/__init__.py b/generative/utils/__init__.py
index bf62a696..6eabaaf2 100644
--- a/generative/utils/__init__.py
+++ b/generative/utils/__init__.py
@@ -11,6 +11,6 @@
 
 from __future__ import annotations
 
-from .enums import AdversarialIterationEvents, AdversarialKeys
 from .component_store import ComponentStore
+from .enums import AdversarialIterationEvents, AdversarialKeys
 from .misc import *
diff --git a/generative/utils/component_store.py b/generative/utils/component_store.py
index 6be1744f..31ad8460 100644
--- a/generative/utils/component_store.py
+++ b/generative/utils/component_store.py
@@ -10,6 +10,7 @@
 # limitations under the License.
from __future__ import annotations -from textwrap import dedent, indent + from collections import namedtuple from keyword import iskeyword -from typing import TypeVar, Callable, Any, Dict, Iterable +from textwrap import dedent, indent +from typing import Any, Callable, Dict, Iterable, TypeVar T = TypeVar("T") diff --git a/generative/utils/misc.py b/generative/utils/misc.py index c8a00177..aea74a81 100644 --- a/generative/utils/misc.py +++ b/generative/utils/misc.py @@ -10,6 +10,7 @@ # limitations under the License. from __future__ import annotations + from typing import TypeVar T = TypeVar("T") diff --git a/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py index 706b82c5..fb3d80b0 100644 --- a/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py +++ b/tutorials/generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.py @@ -395,6 +395,7 @@ # %% [markdown] # ### Visualize anomaly map + # %% def visualize(img): _min = img.min() diff --git a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py index fadae0bc..2db60a88 100644 --- a/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py +++ b/tutorials/generative/anomaly_detection/anomaly_detection_with_transformers.py @@ -335,7 +335,6 @@ progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110) progress_bar.set_description(f"Epoch {epoch}") for step, batch in progress_bar: - images = batch["image"].to(device) optimizer.zero_grad(set_to_none=True) @@ -358,7 +357,6 @@ val_loss = 0 with torch.no_grad(): for val_step, batch in enumerate(val_loader, start=1): - images = batch["image"].to(device) logits, quantizations_target, _ = inferer( diff --git a/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py b/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py index b7fd3d6e..eaa08f5b 100644 --- a/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py +++ b/tutorials/generative/image_to_image_translation/tutorial_segmentation_with_ddpm.py @@ -368,7 +368,6 @@ def dice_coeff(im1, im2, empty_score=1.0): # + for i in range(len(ensemble)): - prediction = torch.where(ensemble[i] > 0.5, 1, 0).float() # a binary mask is obtained via thresholding score = dice_coeff( prediction[0, 0].cpu(), inputlabel.cpu() From 6e4769dbfef06c28e02e2136de523fd97a6c6ed9 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 16 May 2023 16:49:20 +0100 Subject: [PATCH 14/14] Fixes Signed-off-by: Eric Kerfoot --- generative/networks/schedulers/ddim.py | 4 +- generative/networks/schedulers/ddpm.py | 2 +- generative/networks/schedulers/pndm.py | 4 +- generative/networks/schedulers/scheduler.py | 3 -- generative/utils/__init__.py | 2 +- tests/test_compute_multiscalessim_metric.py | 12 ++++-- tests/test_diffusion_model_unet.py | 46 +++++++++++---------- 7 files changed, 38 insertions(+), 35 deletions(-) diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index 10cb5031..7c3de648 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -35,7 +35,7 @@ import torch from monai.utils import StrEnum -from generative.networks.schedulers import Scheduler +from .scheduler 
import Scheduler class DDIMPredictionType(StrEnum): @@ -86,7 +86,7 @@ def __init__( super().__init__(num_train_timesteps, schedule, **schedule_args) if prediction_type not in DDIMPredictionType.__members__.values(): - raise ValueError(f"Argument `prediction_type` must be a member of DDIMPredictionType") + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") self.prediction_type = prediction_type diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index 17d2112c..e543502c 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -35,7 +35,7 @@ import torch from monai.utils import StrEnum -from generative.networks.schedulers import Scheduler +from .scheduler import Scheduler class DDPMVarianceType(StrEnum): diff --git a/generative/networks/schedulers/pndm.py b/generative/networks/schedulers/pndm.py index 5776e5e5..b729315f 100644 --- a/generative/networks/schedulers/pndm.py +++ b/generative/networks/schedulers/pndm.py @@ -37,7 +37,7 @@ import torch from monai.utils import StrEnum -from generative.networks.schedulers import Scheduler +from .scheduler import Scheduler class PNDMPredictionType(StrEnum): @@ -89,7 +89,7 @@ def __init__( super().__init__(num_train_timesteps, schedule, **schedule_args) if prediction_type not in PNDMPredictionType.__members__.values(): - raise ValueError(f"Argument `prediction_type` must be a member of PNDMPredictionType") + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") self.prediction_type = prediction_type diff --git a/generative/networks/schedulers/scheduler.py b/generative/networks/schedulers/scheduler.py index 60c957ff..bf153b8b 100644 --- a/generative/networks/schedulers/scheduler.py +++ b/generative/networks/schedulers/scheduler.py @@ -32,9 +32,6 @@ from __future__ import annotations -from typing import Callable - -import numpy as np import torch import torch.nn as nn diff --git a/generative/utils/__init__.py b/generative/utils/__init__.py index 6eabaaf2..08a1b9b3 100644 --- a/generative/utils/__init__.py +++ b/generative/utils/__init__.py @@ -13,4 +13,4 @@ from .component_store import ComponentStore from .enums import AdversarialIterationEvents, AdversarialKeys -from .misc import * +from .misc import unsqueeze_left, unsqueeze_right diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py index 1f385fd4..85b96991 100644 --- a/tests/test_compute_multiscalessim_metric.py +++ b/tests/test_compute_multiscalessim_metric.py @@ -59,18 +59,22 @@ def test3d_gaussian(self): expected_value = 0.061796 self.assertTrue(expected_value - result.item() < 0.000001) - def input_ill_input_shape(self): + def input_ill_input_shape2d(self): + metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) + with self.assertRaises(ValueError): - metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64)) + def input_ill_input_shape3d(self): + metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) + with self.assertRaises(ValueError): - metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64)) def small_inputs(self): + metric = MultiScaleSSIMMetric(spatial_dims=2) + with self.assertRaises(ValueError): - metric = MultiScaleSSIMMetric(spatial_dims=2) metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) diff --git 
a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index ebda9d31..1769040d 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -331,18 +331,19 @@ def test_with_conditioning_cross_attention_dim_none(self): ) def test_context_with_conditioning_none(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=False, + transformer_num_layers=1, + norm_num_groups=8, + ) + with self.assertRaises(ValueError): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - with_conditioning=False, - transformer_num_layers=1, - norm_num_groups=8, - ) with eval_mode(net): net.forward( x=torch.rand((1, 1, 16, 32)), @@ -371,18 +372,19 @@ def test_shape_conditioned_models_class_conditioning(self): self.assertEqual(result.shape, (1, 1, 16, 32)) def test_conditioned_models_no_class_labels(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with self.assertRaises(ValueError): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - num_head_channels=8, - num_class_embeds=2, - ) net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) def test_model_num_channels_not_same_size_of_attention_levels(self):
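
For downstream code, the net effect of the series is the renamed constructor argument and the "_beta"-suffixed
schedule names; the tutorial and test updates above all follow the same pattern. A before/after sketch using the
values from the tutorials:

.. code-block:: python

    from generative.networks.schedulers import DDPMScheduler

    # before this series (old API):
    #   scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear",
    #                             beta_start=0.0005, beta_end=0.0195)

    # after this series: `beta_schedule` becomes `schedule`, the name gains a "_beta" suffix,
    # and beta_start/beta_end pass through **schedule_args to the registered schedule function
    scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195)
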
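Because `ComponentStore.__str__` now renders each member with its docstring, the set of available schedules can be
discovered by printing the store itself. A sketch; the output shown is abridged and approximate:

.. code-block:: python

    from generative.networks.schedulers import NoiseSchedules

    print(NoiseSchedules)
    # Component Store 'NoiseSchedules': Functions to generate noise schedules
    # Available components:
    # * linear_beta:
    #       Linear beta noise schedule function.
    #       ...
    # * scaled_linear_beta:
    #       Scaled linear beta noise schedule function.
    #       ...
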
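The `unsqueeze_right` helper exported from `generative.utils` is what the schedulers use to broadcast per-timestep
coefficients against batched samples. A usage sketch under the assumption, suggested by the name and its call sites,
that it takes a tensor and a target rank and appends trailing singleton dimensions; the actual definition in misc.py
is not shown in these patches:

.. code-block:: python

    import torch

    from generative.utils import unsqueeze_right

    coef = torch.rand(8)          # one coefficient per batch element, shape (8,)
    x = torch.rand(8, 1, 16, 16)  # a batch of 2D samples
    # assumed behaviour: reshape coef to (8, 1, 1, 1) so it broadcasts over x
    out = unsqueeze_right(coef, x.ndim) * x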