diff --git a/docs/source/en/tutorials/basic_training.mdx b/docs/source/en/tutorials/basic_training.mdx
index 435de38d832f..52ce7c71fa68 100644
--- a/docs/source/en/tutorials/basic_training.mdx
+++ b/docs/source/en/tutorials/basic_training.mdx
@@ -344,7 +344,7 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
 ...         # Sample a random timestep for each image
 ...         timesteps = torch.randint(
-...             0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
+...             0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
 ...         ).long()
 
 ...         # Add noise to the clean images according to the noise magnitude at each timestep
diff --git a/docs/source/en/using-diffusers/contribute_pipeline.mdx b/docs/source/en/using-diffusers/contribute_pipeline.mdx
index ce3f3e823252..8ee6d6ae4fb1 100644
--- a/docs/source/en/using-diffusers/contribute_pipeline.mdx
+++ b/docs/source/en/using-diffusers/contribute_pipeline.mdx
@@ -62,7 +62,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1
 
@@ -108,7 +108,7 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
     def __call__(self):
         image = torch.randn(
-            (1, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
         )
         timestep = 1
 
diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.mdx b/docs/source/en/using-diffusers/custom_pipeline_overview.mdx
index 5c342a5a88e9..934e639983d2 100644
--- a/docs/source/en/using-diffusers/custom_pipeline_overview.mdx
+++ b/docs/source/en/using-diffusers/custom_pipeline_overview.mdx
@@ -89,7 +89,9 @@ class MyPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
         # Sample gaussian noise to begin loop
-        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
+        image = torch.randn(
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
+        )
 
         image = image.to(self.device)
diff --git a/examples/community/bit_diffusion.py b/examples/community/bit_diffusion.py
index c778b6cc6c71..18d5fca5619e 100644
--- a/examples/community/bit_diffusion.py
+++ b/examples/community/bit_diffusion.py
@@ -238,7 +238,7 @@ def __call__(
         **kwargs,
     ) -> Union[Tuple, ImagePipelineOutput]:
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, height, width),
+            (batch_size, self.unet.config.in_channels, height, width),
             generator=generator,
         )
         latents = decimal_to_bits(latents) * self.bit_scale
diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py
index fbb233dccd7a..3f4ab2ab9f4a 100644
--- a/examples/community/clip_guided_stable_diffusion.py
+++ b/examples/community/clip_guided_stable_diffusion.py
@@ -254,7 +254,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/community/clip_guided_stable_diffusion_img2img.py b/examples/community/clip_guided_stable_diffusion_img2img.py
index c3dee5aa9e9a..a72a5a127c72 100644
--- a/examples/community/clip_guided_stable_diffusion_img2img.py
+++ b/examples/community/clip_guided_stable_diffusion_img2img.py
@@ -414,7 +414,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index 35512395ace6..017ad98f291a 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -513,7 +513,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py
index dc8ce5f259dc..56bd381a9e65 100644
--- a/examples/community/imagic_stable_diffusion.py
+++ b/examples/community/imagic_stable_diffusion.py
@@ -424,7 +424,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if self.device.type == "mps":
             # randn does not exist on mps
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index c86e7372a2e1..8f33db71b9f3 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -320,7 +320,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -416,7 +416,7 @@ def embed_text(self, text):
     def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
         """Takes in random seed and returns corresponding noise vector"""
         return torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
             generator=torch.Generator(device=self.device).manual_seed(seed),
             device=self.device,
             dtype=dtype,
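Every hunk so far makes the same mechanical change: shape parameters such as `in_channels` and `sample_size` are read from a model's `config` rather than as attributes of the model object itself. A minimal sketch of the resulting call-site pattern (the pipeline class below is illustrative and not part of this diff):

import torch
from diffusers import DiffusionPipeline


class TinyUncondPipeline(DiffusionPipeline):
    # Hypothetical pipeline, shown only to illustrate the config-based access.
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1):
        # New style: read shape parameters from `unet.config`, not from `unet`.
        image = torch.randn(
            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
        )
        return image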
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index b4863f65abf7..e912ad5244be 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -627,7 +627,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, dev
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index 9aa7d47eeab0..e756097cb7c3 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -486,7 +486,7 @@ def __init__(
         self.__init__additional__()
 
     def __init__additional__(self):
-        self.unet_in_channels = 4
+        self.unet.config.in_channels = 4
         self.vae_scale_factor = 8
 
     def _encode_prompt(
@@ -621,7 +621,7 @@ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, gen
         if image is None:
             shape = (
                 batch_size,
-                self.unet_in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py
index b1d69ec84576..4eb99cb96b42 100644
--- a/examples/community/magic_mix.py
+++ b/examples/community/magic_mix.py
@@ -93,7 +93,7 @@ def __call__(
 
         torch.manual_seed(seed)
         noise = torch.randn(
-            (1, self.unet.in_channels, height // 8, width // 8),
+            (1, self.unet.config.in_channels, height // 8, width // 8),
         ).to(self.device)
 
         latents = self.scheduler.add_noise(
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index f920c4cd59da..ff6c7e68f783 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -355,7 +355,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py
index 78bd7566e6ca..246c3d8c1928 100755
--- a/examples/community/sd_text2img_k_diffusion.py
+++ b/examples/community/sd_text2img_k_diffusion.py
@@ -433,7 +433,7 @@ def __call__(
         sigmas = sigmas.to(text_embeddings.dtype)
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py
index db7c71124254..5891b9fb11a8 100644
--- a/examples/community/seed_resize_stable_diffusion.py
+++ b/examples/community/seed_resize_stable_diffusion.py
@@ -262,8 +262,8 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
-        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py
index 45050137c768..55d805bc8c32 100644
--- a/examples/community/speech_to_image_diffusion.py
+++ b/examples/community/speech_to_image_diffusion.py
@@ -190,7 +190,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py
index 7dd4640243a8..aec79fb8e12e 100644
--- a/examples/community/wildcard_stable_diffusion.py
+++ b/examples/community/wildcard_stable_diffusion.py
@@ -337,7 +337,7 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index a119e12f73d1..a6e0c1af3e1d 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -794,7 +794,7 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index 9db2024bde1e..fd516fff9811 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -794,7 +794,7 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index 321b94bc6cbb..61312fb3a4b3 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -641,7 +641,7 @@ def collate_fn(examples):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index d4d8dae608e3..f415461aaa09 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -804,7 +804,7 @@ def collate_fn(examples):
 
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index c85b339d5b7a..2d657abfa89d 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -707,7 +707,7 @@ def collate_fn(examples):
 
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
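The training scripts above all touch the same line: the random timestep draw. A standalone sketch of the updated step, assuming `noise_scheduler` is a `DDPMScheduler` and `latents` holds a batch of encoded images as in those scripts:

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler()
latents = torch.randn(4, 4, 64, 64)  # stand-in for VAE-encoded training images

noise = torch.randn_like(latents)
bsz = latents.shape[0]
# `num_train_timesteps` is a config value, so it is now read through `.config`.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)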
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index ce6e77b03f57..45930431351a 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -109,13 +109,6 @@ def register_to_config(self, **kwargs):
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
-        for key, value in kwargs.items():
-            try:
-                setattr(self, key, value)
-            except AttributeError as err:
-                logger.error(f"Can't set {key} with value {value} for {self}")
-                raise err
-
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 80e3412991cf..4598e1b4288c 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -99,8 +99,8 @@ def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
         w, h = images.size
-        w, h = (x - x % self.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
+        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
         return images
 
     def preprocess(
@@ -119,7 +119,7 @@ def preprocess(
         )
 
         if isinstance(image[0], PIL.Image.Image):
-            if self.do_resize:
+            if self.config.do_resize:
                 image = [self.resize(i) for i in image]
             image = [np.array(i).astype(np.float32) / 255.0 for i in image]
             image = np.stack(image, axis=0)  # to np
@@ -129,23 +129,27 @@ def preprocess(
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
             image = self.numpy_to_pt(image)
             _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )
 
         elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
            _, _, height, width = image.shape
-            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+            if self.config.do_resize and (
+                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
+            ):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}"
+                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
                     f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                 )
 
         # expected range [0,1], normalize to [-1,1]
-        do_normalize = self.do_normalize
+        do_normalize = self.config.do_normalize
         if image.min() < 0:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 8f65c2357cac..5d1c54a9af25 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook
+from ..utils import BaseOutput, apply_forward_hook, deprecate
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
@@ -120,9 +120,19 @@ def __init__(
             if isinstance(self.config.sample_size, (list, tuple))
             else self.config.sample_size
         )
-        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
+    @property
+    def block_out_channels(self):
+        deprecate(
+            "block_out_channels",
+            "1.0.0",
+            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.block_out_channels
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (Encoder, Decoder)):
             module.gradient_checkpointing = value
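The `configuration_utils.py` hunk above is the root cause of every rename in this diff: `register_to_config` no longer mirrors config values onto the instance, so spellings like `self.do_resize` or `self.vae_scale_factor` stop existing and only the `self.config.*` form remains. Where an old spelling is widely used, a deprecated property (as added to `AutoencoderKL`) bridges the gap. A rough sketch of that pattern, using a made-up class and config field:

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate


class Thing(ConfigMixin):
    config_name = "thing_config.json"

    @register_to_config
    def __init__(self, scale: int = 8):
        # `scale` is recorded in `self.config` but is no longer set on `self`.
        pass

    @property
    def scale(self):
        # Old call sites keep working, but now emit a deprecation warning.
        deprecate("scale", "1.0.0", "Use `thing.config.scale` instead", standard_warn=False)
        return self.config.scale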
diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py
index 34a1d2b5160e..c7755bb3ed45 100644
--- a/src/diffusers/models/unet_1d.py
+++ b/src/diffusers/models/unet_1d.py
@@ -19,7 +19,7 @@ import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
 
@@ -190,6 +190,16 @@ def __init__(
                 fc_dim=block_out_channels[-1] // 4,
             )
 
+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 2df6e60d88c9..d0f2a9cd8a22 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
@@ -215,6 +215,16 @@ def __init__(
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     def forward(
         self,
         sample: torch.FloatTensor,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 72fc0b519ecf..3610231d19e6 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -20,7 +20,7 @@ import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, deprecate, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 
@@ -412,6 +412,16 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
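With these properties in place, legacy code that reads `unet.in_channels` keeps running; the access is simply forwarded to the config with a warning. A quick sanity check of that behavior (this sketch assumes `deprecate` goes through Python's standard `warnings` machinery, as it does elsewhere in the library):

import warnings

from diffusers import UNet2DModel

unet = UNet2DModel()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    channels = unet.in_channels  # deprecated spelling still works

assert channels == unet.config.in_channels
assert any("deprecated" in str(w.message) for w in caught)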
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index c5bb8f9ac7b1..bf314b91116e 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -646,7 +646,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
index 1b88270cbbe6..8d8229e661e8 100644
--- a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
+++ b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
@@ -121,17 +121,17 @@ def __call__(
         self.scheduler.set_timesteps(steps)
         step_generator = step_generator or generator
         # For backwards compatibility
-        if type(self.unet.sample_size) == int:
-            self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
+        if type(self.unet.config.sample_size) == int:
+            self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = randn_tensor(
                 (
                     batch_size,
-                    self.unet.in_channels,
-                    self.unet.sample_size[0],
-                    self.unet.sample_size[1],
+                    self.unet.config.in_channels,
+                    self.unet.config.sample_size[0],
+                    self.unet.config.sample_size[1],
                 ),
                 generator=generator,
                 device=self.device,
@@ -158,7 +158,7 @@ def __call__(
             images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
 
         pixels_per_second = (
-            self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
+            self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
         )
         mask_start = int(mask_start_secs * pixels_per_second)
         mask_end = int(mask_end_secs * pixels_per_second)
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index b392cd4cc246..86a8fd659046 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -540,7 +540,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_waveforms_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
index 018e020491ce..1bfed086e8c6 100644
--- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -61,7 +61,7 @@ def __call__(
                 to make generation deterministic.
             audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                 The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
-                `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
+                `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
 
@@ -73,27 +73,29 @@ def __call__(
         if audio_length_in_s is None:
             audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate
 
-        sample_size = audio_length_in_s * self.unet.sample_rate
+        sample_size = audio_length_in_s * self.unet.config.sample_rate
 
         down_scale_factor = 2 ** len(self.unet.up_blocks)
         if sample_size < 3 * down_scale_factor:
             raise ValueError(
                 f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
-                f" {3 * down_scale_factor / self.unet.sample_rate}."
+                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
             )
 
         original_sample_size = int(sample_size)
         if sample_size % down_scale_factor != 0:
-            sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor
+            sample_size = (
+                (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
+            ) * down_scale_factor
             logger.info(
-                f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled"
-                f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising"
+                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
+                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                 " process."
             )
         sample_size = int(sample_size)
 
         dtype = next(iter(self.unet.parameters())).dtype
-        shape = (batch_size, self.unet.in_channels, sample_size)
+        shape = (batch_size, self.unet.config.in_channels, sample_size)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 0e7f2258fa99..aaf53589b969 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -79,10 +79,15 @@ def __call__(
         """
 
         # Sample gaussian noise to begin loop
-        if isinstance(self.unet.sample_size, int):
-            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.config.sample_size, int):
+            image_shape = (
+                batch_size,
+                self.unet.config.in_channels,
+                self.unet.config.sample_size,
+                self.unet.config.sample_size,
+            )
         else:
-            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
""" # Sample gaussian noise to begin loop - if isinstance(self.unet.sample_size, int): - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) else: - image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) if self.device.type == "mps": # randn does not work reproducibly on mps diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 623b456e52b5..3e4f9425b0f6 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -135,7 +135,7 @@ def __call__( prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0] # get the initial random noise unless the user supplied it - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index 6887068f3443..ae620d325307 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -112,7 +112,7 @@ def __call__( height, width = image.shape[-2:] # in_channels should be 6: 3 for latents, 3 for low resolution image - latents_shape = (batch_size, self.unet.in_channels // 2, height, width) + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) latents_dtype = next(self.unet.parameters()).dtype latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index dc0200feedb1..73c607a27187 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -73,7 +73,7 @@ def __call__( """ latents = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) latents = latents.to(self.device) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index eec8df8a714b..06912a1464eb 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -506,6 +506,21 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) + def __setattr__(self, name: str, value: Any): + if hasattr(self, name) and hasattr(self.config, name): + # We need to overwrite the config if name exists in config + if 
isinstance(getattr(self.config, name), (tuple, list)): + if self.config[name][0] is not None: + class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) + else: + class_library_tuple = (None, None) + + self.register_to_config(**{name: class_library_tuple}) + else: + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -619,9 +634,11 @@ def module_is_offloaded(module): f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." ) - module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names, _ = self._get_signature_keys(self) + module_names = [m for m in module_names if hasattr(self, m)] + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for name in module_names.keys(): + for name in module_names: module = getattr(self, name) if isinstance(module, torch.nn.Module): module.to(torch_device, torch_dtype) @@ -646,8 +663,10 @@ def device(self) -> torch.device: Returns: `torch.device`: The torch device on which the pipeline is located. """ - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): + module_names, _ = self._get_signature_keys(self) + module_names = [m for m in module_names if hasattr(self, m)] + + for name in module_names: module = getattr(self, name) if isinstance(module, torch.nn.Module): return module.device @@ -1420,6 +1439,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): fn_recursive_set_mem_eff(child) module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names = [m for m in module_names if hasattr(self, m)] + for module_name in module_names: module = getattr(self, module_name) if isinstance(module, torch.nn.Module): @@ -1451,6 +1472,8 @@ def disable_attention_slicing(self): def set_attention_slice(self, slice_size: Optional[int]): module_names, _, _ = self.extract_init_dict(dict(self.config)) + module_names = [m for m in module_names if hasattr(self, m)] + for module_name in module_names: module = getattr(self, module_name) if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 56fb72d3f4ff..361444079311 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -77,7 +77,7 @@ def __call__( # Sample gaussian noise to begin loop image = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, device=self.device, ) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 69703fb8d82c..3d5374875d12 100644 --- 
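The new `__setattr__` on `DiffusionPipeline` is the behavioral core of this change for pipelines: assigning a component now also rewrites the corresponding entry in the pipeline config, so the swap survives a `save_pretrained`/`from_pretrained` round trip, which is exactly what the `test_save_pipeline_change_config` test added at the bottom of this diff asserts. A sketch of the user-visible effect, using the same tiny checkpoint as the test (the local save path is illustrative):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

# Assigning the attribute updates the ("library", "class") entry in pipe.config ...
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# ... so the saved model_index.json names DPMSolverMultistepScheduler, and
# reloading restores the swapped scheduler rather than the original PNDMScheduler.
pipe.save_pretrained("./tiny-sd-dpm")
reloaded = DiffusionPipeline.from_pretrained("./tiny-sd-dpm")
assert reloaded.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler"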
diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index 56fb72d3f4ff..361444079311 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -77,7 +77,7 @@ def __call__(
 
         # Sample gaussian noise to begin loop
         image = randn_tensor(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
             generator=generator,
             device=self.device,
         )
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 69703fb8d82c..3d5374875d12 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -476,7 +476,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 066d1e99acaa..c0c2ee8b8aaa 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -247,7 +247,7 @@ def _generate(
 
         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
index 5af07ec8b9c4..df3e79a194f8 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
@@ -283,7 +283,7 @@ def _generate(
 
         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index 2063238df27a..6a387af364b7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -268,7 +268,7 @@ def _generate(
 
         latents_shape = (
             batch_size,
-            self.unet.in_channels,
+            self.unet.config.in_channels,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 73b9178e3ab1..fcf44f02c731 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -649,7 +649,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index 46adb6967140..35351bae7116 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -855,7 +855,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
index b8272a4ef3d6..12d21afbfeda 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -910,7 +910,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 835fba10dee4..d543593fdbf5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -358,7 +358,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 7135b3e3ba31..277a4df0569d 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -561,7 +561,7 @@ def __call__(
         sigmas = sigmas.to(prompt_embeds.dtype)
 
         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
index d841bd8a2d26..b7ded03d529b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -722,7 +722,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index c47423bdee5b..d2d7330554ba 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -586,7 +586,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 6af923cb7743..e457ad2b3afc 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -929,7 +929,7 @@ def __call__(
 
         # 5. Generate the inverted noise from the input image or any other image
         # generated from the input prompt.
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 2b08cf662bb4..063882284754 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -595,7 +595,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -701,7 +701,7 @@ def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
-        h = self.unet.attention_head_dim
+        h = self.unet.config.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index ce41572e683c..fafb8d1d2800 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -877,7 +877,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 11. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         latents = self.prepare_latents(
             shape=shape,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index b9bf00bc7835..22b7280f3679 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -772,7 +772,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 6. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size=batch_size,
             num_channels_latents=num_channels_latents,
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index 850a0a4670e2..87e7b3e6c9eb 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -623,7 +623,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 1cbe78f0c964..6fc89e945604 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -606,7 +606,7 @@ def __call__(
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 7d68f6f06ef6..6a3635613104 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -12,7 +12,7 @@
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import logging
+from ...utils import deprecate, logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -504,6 +504,19 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+    @property
+    def in_channels(self):
+        deprecate(
+            "in_channels",
+            "1.0.0",
+            (
+                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
+                " `unet.config.in_channels` instead"
+            ),
+            standard_warn=False,
+        )
+        return self.config.in_channels
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 9fb36db52df5..eaaf497f9c1d 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -22,7 +22,7 @@ import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, randn_tensor
+from ..utils import BaseOutput, deprecate, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
@@ -167,6 +167,16 @@ def __init__(
 
         self.variance_type = variance_type
 
+    @property
+    def num_train_timesteps(self):
+        deprecate(
+            "num_train_timesteps",
+            "1.0.0",
+            "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps` instead",
+            standard_warn=False,
+        )
+        return self.config.num_train_timesteps
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index ffe3ec64f9ae..7aebda205e5b 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -183,7 +183,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index c41fc7e16a4f..dfdfac3085d2 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -193,7 +193,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 684a1eec7c1a..049e2b1dbd4d 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -190,8 +190,8 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         steps = num_inference_steps
-        order = self.solver_order
-        if self.lower_order_final:
+        order = self.config.solver_order
+        if self.config.lower_order_final:
             if order == 3:
                 if steps % 3 == 0:
                     orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]
@@ -227,7 +227,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         """
         self.num_inference_steps = num_inference_steps
         timesteps = (
-            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
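The scheduler hunks are again mechanical (`self.num_train_timesteps` becomes `self.config.num_train_timesteps`), but the shared `set_timesteps` expression is worth seeing on its own. The same arithmetic rendered standalone, assuming the default of 1000 training timesteps:

import numpy as np

num_train_timesteps = 1000  # i.e. scheduler.config.num_train_timesteps
num_inference_steps = 4

# An even grid over [0, 999], rounded, reversed, with the trailing 0 dropped:
timesteps = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)
print(timesteps)  # [999 749 500 250]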
""" timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 9119ae30f42f..0bb10c3d5185 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -73,7 +73,7 @@ def __call__( # Sample gaussian noise to begin loop image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) image = image.to(self.device) diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py index a8af08d3980a..494c5a1a4e95 100644 --- a/tests/fixtures/custom_pipeline/what_ever.py +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -73,7 +73,7 @@ def __call__( # Sample gaussian noise to begin loop image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, ) image = image.to(self.device) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index b814f5f88a30..d3a3d5cfc9a0 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -116,7 +116,7 @@ def test_output_pretrained(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - num_features = model.in_channels + num_features = model.config.in_channels seq_len = 16 noise = torch.randn((1, seq_len, num_features)).permute( 0, 2, 1 @@ -264,7 +264,7 @@ def test_output_pretrained(self): if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) - num_features = value_function.in_channels + num_features = value_function.config.in_channels seq_len = 14 noise = torch.randn((1, seq_len, num_features)).permute( 0, 2, 1 diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 08cb03f55aaa..048030d98371 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -675,6 +675,25 @@ def test_download_from_git(self): image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0] assert image.shape == (512, 512, 3) + def test_save_pipeline_change_config(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = DiffusionPipeline.from_pretrained(tmpdirname) + + assert pipe.scheduler.__class__.__name__ == "PNDMScheduler" + + # let's make sure that changing the scheduler is correctly reflected + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.save_pretrained(tmpdirname) + pipe = DiffusionPipeline.from_pretrained(tmpdirname) + + assert pipe.scheduler.__class__.__name__ == "DPMSolverMultistepScheduler" + class PipelineFastTests(unittest.TestCase): def tearDown(self):