diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2ab85ecee16d..1c5c3d7afd58 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -154,9 +154,12 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + model = cls(**init_dict) if return_unused_kwargs: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 7f068d7183ba..675b61266285 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`FlaxSchedulerMixin`]): + scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): @@ -157,7 +157,7 @@ def __call__( self.unet.sample_size, ) if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")