diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 4aee44c56c3e..fb0ce92cb61c 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -134,6 +135,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - name: Environment shell: arch -arch arm64 bash {0} @@ -157,4 +159,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: torch_mps_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 93bbdae388e6..082b12404a85 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -60,6 +60,7 @@ jobs: apt-get update && apt-get install libsndfile1-dev -y python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -126,6 +127,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index df3a3bf0fdf2..2d4875b80ced 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install -U git+https://github.com/huggingface/transformers + python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -130,6 +131,7 @@ jobs: - name: Install dependencies run: | python -m pip install -e .[quality,test,training] + python -m pip install git+https://github.com/huggingface/accelerate python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment @@ -151,4 +153,4 @@ jobs: uses: actions/upload-artifact@v2 with: name: examples_test_reports - path: reports \ No newline at end of file + path: reports diff --git a/docs/README.md b/docs/README.md index b2d48dfee152..23b96fa82a85 100644 --- a/docs/README.md +++ b/docs/README.md @@ -155,9 +155,9 @@ adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`funct function to be in the main package. If you want to create a link to some internal class or function, you need to -provide its path. For instance: \[\`pipeline_utils.ImagePipelineOutput\`\]. This will be converted into a link with -`pipeline_utils.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are -linking to in the description, add a ~: \[\`~pipeline_utils.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description. +provide its path. For instance: \[\`pipelines.ImagePipelineOutput\`\]. This will be converted into a link with +`pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are +linking to in the description, add a ~: \[\`~pipelines.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description. The same works for methods so you can either use \[\`XXXClass.method\`\] or \[~\`XXXClass.method\`\]. diff --git a/docs/source/api/diffusion_pipeline.mdx b/docs/source/api/diffusion_pipeline.mdx index b037b4e26dc1..b5d56fb315d4 100644 --- a/docs/source/api/diffusion_pipeline.mdx +++ b/docs/source/api/diffusion_pipeline.mdx @@ -30,13 +30,17 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain ## DiffusionPipeline [[autodoc]] DiffusionPipeline - - from_pretrained - - save_pretrained - - to + - all + - __call__ - device - - components + - to ## ImagePipelineOutput By default diffusion pipelines return an object of class -[[autodoc]] pipeline_utils.ImagePipelineOutput +[[autodoc]] pipelines.ImagePipelineOutput + +## AudioPipelineOutput +By default diffusion pipelines return an object of class + +[[autodoc]] pipelines.AudioPipelineOutput diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 678c902d4504..0a2183f08c32 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -41,13 +41,13 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module [[autodoc]] models.vae.DecoderOutput ## VQEncoderOutput -[[autodoc]] models.vae.VQEncoderOutput +[[autodoc]] models.vq_model.VQEncoderOutput ## VQModel [[autodoc]] VQModel ## AutoencoderKLOutput -[[autodoc]] models.vae.AutoencoderKLOutput +[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput ## AutoencoderKL [[autodoc]] AutoencoderKL diff --git a/docs/source/api/outputs.mdx b/docs/source/api/outputs.mdx index 010761fb2e4b..291a79756a16 100644 --- a/docs/source/api/outputs.mdx +++ b/docs/source/api/outputs.mdx @@ -25,7 +25,7 @@ pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32") outputs = pipeline() ``` -The `outputs` object is a [`~pipeline_utils.ImagePipelineOutput`], as we can see in the +The `outputs` object is a [`~pipelines.ImagePipelineOutput`], as we can see in the documentation of that class below, it means it has an image attribute. You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you will get `None`: diff --git a/docs/source/api/pipelines/alt_diffusion.mdx b/docs/source/api/pipelines/alt_diffusion.mdx index 8d7d795d7633..0f497a390a7c 100644 --- a/docs/source/api/pipelines/alt_diffusion.mdx +++ b/docs/source/api/pipelines/alt_diffusion.mdx @@ -69,15 +69,15 @@ If you want to use all possible use cases in a single `DiffusionPipeline` we rec ## AltDiffusionPipelineOutput [[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput + - all + - __call__ ## AltDiffusionPipeline [[autodoc]] AltDiffusionPipeline + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing ## AltDiffusionImg2ImgPipeline [[autodoc]] AltDiffusionImg2ImgPipeline + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing diff --git a/docs/source/api/pipelines/audio_diffusion.mdx b/docs/source/api/pipelines/audio_diffusion.mdx index bafdbef28140..ec9b1fb2d304 100644 --- a/docs/source/api/pipelines/audio_diffusion.mdx +++ b/docs/source/api/pipelines/audio_diffusion.mdx @@ -91,12 +91,8 @@ display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) ## AudioDiffusionPipeline [[autodoc]] AudioDiffusionPipeline - - __call__ - - encode - - slerp - + - all + - __call__ ## Mel [[autodoc]] Mel - - audio_slice_to_image - - image_to_audio diff --git a/docs/source/api/pipelines/cycle_diffusion.mdx b/docs/source/api/pipelines/cycle_diffusion.mdx index b5c078ad9466..70986bd39a3d 100644 --- a/docs/source/api/pipelines/cycle_diffusion.mdx +++ b/docs/source/api/pipelines/cycle_diffusion.mdx @@ -96,4 +96,5 @@ image.save("black_to_blue.png") ## CycleDiffusionPipeline [[autodoc]] CycleDiffusionPipeline + - all - __call__ diff --git a/docs/source/api/pipelines/dance_diffusion.mdx b/docs/source/api/pipelines/dance_diffusion.mdx index 4d969bf6f032..8264de7db603 100644 --- a/docs/source/api/pipelines/dance_diffusion.mdx +++ b/docs/source/api/pipelines/dance_diffusion.mdx @@ -30,4 +30,5 @@ The original codebase of this implementation can be found [here](https://github. ## DanceDiffusionPipeline [[autodoc]] DanceDiffusionPipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/ddim.mdx b/docs/source/api/pipelines/ddim.mdx index a7a5421b36fe..b1dfa3b056a8 100644 --- a/docs/source/api/pipelines/ddim.mdx +++ b/docs/source/api/pipelines/ddim.mdx @@ -32,4 +32,5 @@ For questions, feel free to contact the author on [tsong.me](https://tsong.me/). ## DDIMPipeline [[autodoc]] DDIMPipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/ddpm.mdx b/docs/source/api/pipelines/ddpm.mdx index c6d8a6f28660..92cee580d152 100644 --- a/docs/source/api/pipelines/ddpm.mdx +++ b/docs/source/api/pipelines/ddpm.mdx @@ -33,4 +33,5 @@ The original codebase of this paper can be found [here](https://github.com/hojon # DDPMPipeline [[autodoc]] DDPMPipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx index 370d014f5a10..475957d93cd8 100644 --- a/docs/source/api/pipelines/latent_diffusion.mdx +++ b/docs/source/api/pipelines/latent_diffusion.mdx @@ -40,8 +40,10 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff ## LDMTextToImagePipeline [[autodoc]] LDMTextToImagePipeline - - __call__ + - all + - __call__ ## LDMSuperResolutionPipeline [[autodoc]] LDMSuperResolutionPipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/latent_diffusion_uncond.mdx b/docs/source/api/pipelines/latent_diffusion_uncond.mdx index 0a5b20cd4a1c..03f1f31cee5d 100644 --- a/docs/source/api/pipelines/latent_diffusion_uncond.mdx +++ b/docs/source/api/pipelines/latent_diffusion_uncond.mdx @@ -38,4 +38,5 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff ## LDMPipeline [[autodoc]] LDMPipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/paint_by_example.mdx b/docs/source/api/pipelines/paint_by_example.mdx index e40b3453edf4..91b936d98ac0 100644 --- a/docs/source/api/pipelines/paint_by_example.mdx +++ b/docs/source/api/pipelines/paint_by_example.mdx @@ -69,5 +69,6 @@ image ``` ## PaintByExamplePipeline -[[autodoc]] pipelines.paint_by_example.pipeline_paint_by_example.PaintByExamplePipeline - - __call__ +[[autodoc]] PaintByExamplePipeline + - all + - __call__ diff --git a/docs/source/api/pipelines/pndm.mdx b/docs/source/api/pipelines/pndm.mdx index 89930f4d4f8f..824a927d8bc3 100644 --- a/docs/source/api/pipelines/pndm.mdx +++ b/docs/source/api/pipelines/pndm.mdx @@ -30,6 +30,6 @@ The original codebase can be found [here](https://github.com/luping-liu/PNDM). ## PNDMPipeline -[[autodoc]] pipelines.pndm.pipeline_pndm.PNDMPipeline - - __call__ - +[[autodoc]] PNDMPipeline + - all + - __call__ diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx index ce262daffaeb..d0a3a6875b24 100644 --- a/docs/source/api/pipelines/repaint.mdx +++ b/docs/source/api/pipelines/repaint.mdx @@ -72,6 +72,6 @@ inpainted_image = output.images[0] ``` ## RePaintPipeline -[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline - - __call__ - +[[autodoc]] RePaintPipeline + - all + - __call__ diff --git a/docs/source/api/pipelines/score_sde_ve.mdx b/docs/source/api/pipelines/score_sde_ve.mdx index 3d6619c2591f..7a5d7ee83aa5 100644 --- a/docs/source/api/pipelines/score_sde_ve.mdx +++ b/docs/source/api/pipelines/score_sde_ve.mdx @@ -32,5 +32,5 @@ This pipeline implements the Variance Expanding (VE) variant of the method. ## ScoreSdeVePipeline [[autodoc]] ScoreSdeVePipeline - - __call__ - + - all + - __call__ diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index f16c2cba0274..fe28155600ce 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -73,16 +73,18 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionPipeline [[autodoc]] StableDiffusionPipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing - - enable_vae_slicing - - disable_vae_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + + ## StableDiffusionImg2ImgPipeline [[autodoc]] StableDiffusionImg2ImgPipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing @@ -91,6 +93,7 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionInpaintPipeline [[autodoc]] StableDiffusionInpaintPipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing @@ -99,6 +102,7 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionDepth2ImgPipeline [[autodoc]] StableDiffusionDepth2ImgPipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing @@ -107,15 +111,16 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ## StableDiffusionImageVariationPipeline [[autodoc]] StableDiffusionImageVariationPipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention - ## StableDiffusionUpscalePipeline [[autodoc]] StableDiffusionUpscalePipeline + - all - __call__ - enable_attention_slicing - disable_attention_slicing diff --git a/docs/source/api/pipelines/stable_diffusion_safe.mdx b/docs/source/api/pipelines/stable_diffusion_safe.mdx index 81fc59d3928c..c700b9c9f69e 100644 --- a/docs/source/api/pipelines/stable_diffusion_safe.mdx +++ b/docs/source/api/pipelines/stable_diffusion_safe.mdx @@ -81,10 +81,10 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro ## StableDiffusionSafePipelineOutput [[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput + - all + - __call__ ## StableDiffusionPipelineSafe [[autodoc]] StableDiffusionPipelineSafe + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing - diff --git a/docs/source/api/pipelines/stochastic_karras_ve.mdx b/docs/source/api/pipelines/stochastic_karras_ve.mdx index de762cbda002..ab185ec20d6c 100644 --- a/docs/source/api/pipelines/stochastic_karras_ve.mdx +++ b/docs/source/api/pipelines/stochastic_karras_ve.mdx @@ -32,4 +32,5 @@ This pipeline implements the Stochastic sampling tailored to the Variance-Expand ## KarrasVePipeline [[autodoc]] KarrasVePipeline - - __call__ + - all + - __call__ diff --git a/docs/source/api/pipelines/unclip.mdx b/docs/source/api/pipelines/unclip.mdx index 0d2e17601261..87d44adc0d76 100644 --- a/docs/source/api/pipelines/unclip.mdx +++ b/docs/source/api/pipelines/unclip.mdx @@ -24,8 +24,14 @@ The unCLIP model in diffusers comes from kakaobrain's karlo and the original cod | Pipeline | Tasks | Colab |---|---|:---:| | [pipeline_unclip.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/unclip/pipeline_unclip.py) | *Text-to-Image Generation* | - | +| [pipeline_unclip_image_variation.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py) | *Image-Guided Image Generation* | - | ## UnCLIPPipeline -[[autodoc]] pipelines.unclip.pipeline_unclip.UnCLIPPipeline - - __call__ \ No newline at end of file +[[autodoc]] UnCLIPPipeline + - all + - __call__ + +[[autodoc]] UnCLIPImageVariationPipeline + - all + - __call__ diff --git a/docs/source/api/pipelines/versatile_diffusion.mdx b/docs/source/api/pipelines/versatile_diffusion.mdx index f557c5b0aac8..80da6768bfcf 100644 --- a/docs/source/api/pipelines/versatile_diffusion.mdx +++ b/docs/source/api/pipelines/versatile_diffusion.mdx @@ -56,18 +56,15 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro ## VersatileDiffusionTextToImagePipeline [[autodoc]] VersatileDiffusionTextToImagePipeline + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing ## VersatileDiffusionImageVariationPipeline [[autodoc]] VersatileDiffusionImageVariationPipeline + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing ## VersatileDiffusionDualGuidedPipeline [[autodoc]] VersatileDiffusionDualGuidedPipeline + - all - __call__ - - enable_attention_slicing - - disable_attention_slicing diff --git a/docs/source/api/pipelines/vq_diffusion.mdx b/docs/source/api/pipelines/vq_diffusion.mdx index 92cc903eee79..459c65293589 100644 --- a/docs/source/api/pipelines/vq_diffusion.mdx +++ b/docs/source/api/pipelines/vq_diffusion.mdx @@ -30,5 +30,6 @@ The original codebase can be found [here](https://github.com/microsoft/VQ-Diffus ## VQDiffusionPipeline -[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline - - __call__ +[[autodoc]] VQDiffusionPipeline + - all + - __call__ diff --git a/examples/community/README.md b/examples/community/README.md index ddb0b8ce9389..a848f74f2a29 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -25,6 +25,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | | Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | +MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | @@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4') plt.axis('off') plt.show() +``` + +As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. + +### Magic Mix + +Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process. + +There are 3 parameters for the method- +- `mix_factor`: It is the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process. +- `kmax` and `kmin`: These determine the range for the layout and content generation process. A higher value of kmax results in loss of more information about the layout of the original image and a higher value of kmin results in more steps for content generation process. + +Here is an example usage- + ```python +from diffusers import DiffusionPipeline, DDIMScheduler +from PIL import Image + +pipe = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="magic_mix", + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), +).to('cuda') + +img = Image.open('phone.jpg') +mix_img = pipe( + img, + prompt = 'bed', + kmin = 0.3, + kmax = 0.5, + mix_factor = 0.5, + ) +mix_img.save('phone_bed_mix.jpg') +``` +The `mix_img` is a PIL image that can be saved locally or displayed directly in a google colab. Generated image is a mix of the layout semantics of the given image and the content semantics of the prompt. + +E.g. the above script generates the following image: + +`phone.jpg` + +![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg) + +`phone_bed_mix.jpg` + +![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg) -As a result, you can look at a grid of all 4 generated images being shown together, that captures a difference the advancement of the training between the 4 checkpoints. \ No newline at end of file +For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb). diff --git a/examples/community/bit_diffusion.py b/examples/community/bit_diffusion.py index 956e25a7e5c5..b27b67c97a36 100644 --- a/examples/community/bit_diffusion.py +++ b/examples/community/bit_diffusion.py @@ -2,8 +2,7 @@ import torch -from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput from einops import rearrange, reduce diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py index 982da6646a76..b6e418221ff7 100644 --- a/examples/community/checkpoint_merger.py +++ b/examples/community/checkpoint_merger.py @@ -5,13 +5,7 @@ import torch from diffusers import DiffusionPipeline, __version__ -from diffusers.pipeline_utils import ( - CONFIG_NAME, - DIFFUSERS_CACHE, - ONNX_WEIGHTS_NAME, - SCHEDULER_CONFIG_NAME, - WEIGHTS_NAME, -) +from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, WEIGHTS_NAME from huggingface_hub import snapshot_download diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index 1d7b63711ccd..7ee997750b04 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -17,14 +17,10 @@ import torch -from diffusers.utils import is_accelerate_available -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import ( +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, @@ -32,6 +28,10 @@ LMSDiscreteScheduler, PNDMScheduler, ) +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index 2488675f59c8..8b5bb7d060d2 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -12,8 +12,8 @@ import PIL from accelerate import Accelerator +from diffusers import DiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index 3fa7db13a482..cb8071e831b0 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -5,9 +5,9 @@ import torch import PIL +from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index 4d7a73f5ba69..9087dd7d2ca6 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -6,9 +6,9 @@ import numpy as np import torch +from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index 58165dbd2a4c..f37bd6e816a0 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -7,8 +7,7 @@ import diffusers import PIL -from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin -from diffusers.onnx_utils import OnnxRuntimeModel +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import deprecate, logging from packaging import version @@ -16,7 +15,7 @@ try: - from diffusers.onnx_utils import ORT_TO_NP_TYPE + from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE except ImportError: ORT_TO_NP_TYPE = { "tensor(bool)": np.bool_, diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py new file mode 100644 index 000000000000..d67aec781c36 --- /dev/null +++ b/examples/community/magic_mix.py @@ -0,0 +1,152 @@ +from typing import Union + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from PIL import Image +from torchvision import transforms as tfms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +class MagicMixPipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler], + ): + super().__init__() + + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + # convert PIL image to latents + def encode(self, img): + with torch.no_grad(): + latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1) + latent = 0.18215 * latent.latent_dist.sample() + return latent + + # convert latents to PIL image + def decode(self, latent): + latent = (1 / 0.18215) * latent + with torch.no_grad(): + img = self.vae.decode(latent).sample + img = (img / 2 + 0.5).clamp(0, 1) + img = img.detach().cpu().permute(0, 2, 3, 1).numpy() + img = (img * 255).round().astype("uint8") + return Image.fromarray(img[0]) + + # convert prompt into text embeddings, also unconditional embeddings + def prep_text(self, prompt): + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0] + + uncond_input = self.tokenizer( + "", + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + return torch.cat([uncond_embedding, text_embedding]) + + def __call__( + self, + img: Image.Image, + prompt: str, + kmin: float = 0.3, + kmax: float = 0.6, + mix_factor: float = 0.5, + seed: int = 42, + steps: int = 50, + guidance_scale: float = 7.5, + ) -> Image.Image: + tmin = steps - int(kmin * steps) + tmax = steps - int(kmax * steps) + + text_embeddings = self.prep_text(prompt) + + self.scheduler.set_timesteps(steps) + + width, height = img.size + encoded = self.encode(img) + + torch.manual_seed(seed) + noise = torch.randn( + (1, self.unet.in_channels, height // 8, width // 8), + ).to(self.device) + + latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=self.scheduler.timesteps[tmax], + ) + + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax]) + + with torch.no_grad(): + pred = self.unet( + input, + self.scheduler.timesteps[tmax], + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample + + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + if i > tmax: + if i < tmin: # layout generation phase + orig_latents = self.scheduler.add_noise( + encoded, + noise, + timesteps=t, + ) + + input = (mix_factor * latents) + ( + 1 - mix_factor + ) * orig_latents # interpolating between layout noise and conditionally generated noise to preserve layout sematics + input = torch.cat([input] * 2) + + else: # content generation phase + input = torch.cat([latents] * 2) + + input = self.scheduler.scale_model_input(input, t) + + with torch.no_grad(): + pred = self.unet( + input, + t, + encoder_hidden_states=text_embeddings, + ).sample + + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + + latents = self.scheduler.step(pred, t, latents).prev_sample + + return self.decode(latents) diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py index 19974d6df08b..bcadc8037395 100644 --- a/examples/community/multilingual_stable_diffusion.py +++ b/examples/community/multilingual_stable_diffusion.py @@ -3,9 +3,9 @@ import torch +from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py index 6b15674ea0fa..d5b32024105d 100755 --- a/examples/community/sd_text2img_k_diffusion.py +++ b/examples/community/sd_text2img_k_diffusion.py @@ -18,8 +18,7 @@ import torch -from diffusers import LMSDiscreteScheduler -from diffusers.pipeline_utils import DiffusionPipeline +from diffusers import DiffusionPipeline, LMSDiscreteScheduler from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import is_accelerate_available, logging from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py index 92cd1c04f9f3..a3d17441587f 100644 --- a/examples/community/seed_resize_stable_diffusion.py +++ b/examples/community/seed_resize_stable_diffusion.py @@ -6,8 +6,8 @@ import torch +from diffusers import DiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py index f02d449fbd1d..8a12044d4c1f 100644 --- a/examples/community/text_inpainting.py +++ b/examples/community/text_inpainting.py @@ -3,9 +3,9 @@ import torch import PIL +from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py index ee45e62a237c..6ba574cd2144 100644 --- a/examples/community/wildcard_stable_diffusion.py +++ b/examples/community/wildcard_stable_diffusion.py @@ -7,9 +7,9 @@ import torch +from diffusers import DiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 2bbdb7a5da8f..2858c04c48b0 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -317,4 +317,7 @@ python train_dreambooth_flax.py \ --max_train_steps=800 ``` -You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint). \ No newline at end of file +### Training with xformers: +You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and padding the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation. + +You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint). diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 122d346ff5ce..30f5e0ccae0b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -247,7 +247,20 @@ def parse_args(input_args=None): " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) if input_args is not None: args = parser.parse_args(input_args) @@ -433,6 +446,12 @@ def main(args): if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -516,14 +535,11 @@ def main(args): revision=args.revision, ) - if is_xformers_available(): - try: + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() - except Exception as e: - logger.warning( - "Could not enable memory efficient attention. Make sure xformers is installed" - f" correctly and a GPU is available: {e}" - ) + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") vae.requires_grad_(False) if not args.train_text_encoder: @@ -716,7 +732,7 @@ def main(args): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 407578e3b717..e98e136a4b31 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -160,3 +160,6 @@ python train_text_to_image_flax.py \ --max_grad_norm=1 \ --output_dir="sd-pokemon-model" ``` + +### Training with xformers: +You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and padding the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6c45ee0b1b65..986a57b75779 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1,4 +1,5 @@ import argparse +import copy import logging import math import os @@ -11,6 +12,9 @@ import torch.nn.functional as F import torch.utils.checkpoint +import datasets +import diffusers +import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -28,7 +32,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.10.0.dev0") -logger = get_logger(__name__) +logger = get_logger(__name__, log_level="INFO") def parse_args(): @@ -171,7 +175,25 @@ def parse_args(): parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -234,6 +256,9 @@ def parse_args(): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -244,6 +269,10 @@ def parse_args(): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + return args @@ -272,27 +301,24 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] + self.collected_params = None + self.decay = decay self.optimization_step = 0 - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - value = (1 + optimization_step) / (10 + optimization_step) - return 1 - min(self.decay, value) - @torch.no_grad() def step(self, parameters): parameters = list(parameters) self.optimization_step += 1 - self.decay = self.get_decay(self.optimization_step) + + # Compute the decay factor for the exponential moving average. + value = (1 + self.optimization_step) / (10 + self.optimization_step) + one_minus_decay = 1 - min(self.decay, value) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - tmp = self.decay * (s_param - param) - s_param.sub_(tmp) + s_param.sub_(one_minus_decay * (s_param - param)) else: s_param.copy_(param) @@ -324,6 +350,55 @@ def to(self, device=None, dtype=None) -> None: for p in self.shadow_params ] + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. + This method is used by accelerate during checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "optimization_step": self.optimization_step, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Loads the ExponentialMovingAverage state. + This method is used by accelerate during checkpointing to save the ema state dict. + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.optimization_step = state_dict["optimization_step"] + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") + def main(): args = parse_args() @@ -341,6 +416,15 @@ def main(): datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: @@ -363,42 +447,44 @@ def main(): elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Load models and create wrapper for stable diffusion + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - revision=args.revision, - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) - if is_xformers_available(): - try: - unet.enable_xformers_memory_efficient_attention() - except Exception as e: - logger.warning( - "Could not enable memory efficient attention. Make sure xformers is installed" - f" correctly and a GPU is available: {e}" - ) - # Freeze vae and text_encoder vae.requires_grad_(False) text_encoder.requires_grad_(False) + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + ema_unet = EMAModel(ema_unet.parameters()) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes @@ -424,7 +510,6 @@ def main(): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -487,13 +572,14 @@ def tokenize_captions(examples, is_train=True): raise ValueError( f"Caption column `{caption_column}` should contain either strings or lists of strings." ) - inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) - input_ids = inputs.input_ids - return input_ids + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids train_transforms = transforms.Compose( [ - transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), @@ -505,7 +591,6 @@ def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] examples["input_ids"] = tokenize_captions(examples) - return examples with accelerator.main_process_first(): @@ -517,13 +602,8 @@ def preprocess_train(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = [example["input_ids"] for example in examples] - padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt") - return { - "pixel_values": pixel_values, - "input_ids": padded_tokens.input_ids, - "attention_mask": padded_tokens.attention_mask, - } + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size @@ -546,23 +626,22 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) + if args.use_ema: + accelerator.register_for_checkpointing(ema_unet) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move text_encode and vae to gpu. - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. + # Move text_encode and vae to gpu and cast to weight_dtype text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - - # Create EMA for the unet. if args.use_ema: - ema_unet = EMAModel(unet.parameters()) + ema_unet.to(accelerator.device) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index b3379226f243..4554cdd082e1 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -333,7 +333,7 @@ def tokenize_captions(examples, is_train=True): train_transforms = transforms.Compose( [ - transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index a2bde75b51de..3a7c96be69fb 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -124,3 +124,6 @@ python textual_inversion_flax.py \ --output_dir="textual_inversion_cat" ``` It should be at least 70% faster than the PyTorch script with the same configuration. + +### Training with xformers: +You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and padding the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation. diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 7fbca761bdc8..2a765e47a20b 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -1,5 +1,4 @@ import argparse -import itertools import math import os import random @@ -147,6 +146,11 @@ def parse_args(): default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) parser.add_argument( "--learning_rate", type=float, @@ -222,6 +226,9 @@ def parse_args(): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -380,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -def freeze_params(params): - for param in params: - param.requires_grad = False - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) @@ -457,14 +459,15 @@ def main(): revision=args.revision, ) - if is_xformers_available(): - try: + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + unet.enable_gradient_checkpointing() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() - except Exception as e: - logger.warning( - "Could not enable memory efficient attention. Make sure xformers is installed" - f" correctly and a GPU is available: {e}" - ) + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -474,15 +477,12 @@ def main(): token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] # Freeze vae and unet - freeze_params(vae.parameters()) - freeze_params(unet.parameters()) + vae.requires_grad_(False) + unet.requires_grad_(False) # Freeze all parameters except for the token embeddings in text encoder - params_to_freeze = itertools.chain( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - text_encoder.text_model.embeddings.position_embedding.parameters(), - ) - freeze_params(params_to_freeze) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) if args.scale_lr: args.learning_rate = ( @@ -541,9 +541,10 @@ def main(): unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - # Keep vae and unet in eval model as we don't train these - vae.eval() - unet.eval() + # Keep unet in train mode if we are using gradient checkpointing to save memory. + # The dropout is 0 so it doesn't matter if we are in eval or train mode. + if args.gradient_checkpointing: + unet.train() # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -609,12 +610,11 @@ def main(): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device - ).long() + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -634,7 +634,8 @@ def main(): else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) optimizer.step() @@ -669,8 +670,7 @@ def main(): if global_step >= args.max_train_steps: break - accelerator.wait_for_everyone() - + accelerator.wait_for_everyone() # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: if args.push_to_hub and args.only_save_embeds: diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 0414a0e8ad6a..1f4204495482 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -848,12 +848,17 @@ def convert_open_clip_checkpoint(checkpoint): ), ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") args = parser.parse_args() image_size = args.image_size prediction_type = args.prediction_type - checkpoint = torch.load(args.checkpoint_path) + if args.device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(args.checkpoint_path, map_location=device) + else: + checkpoint = torch.load(args.checkpoint_path, map_location=args.device) # Sometimes models don't have the global_step item if "global_step" in checkpoint: diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index 26d3d5618f88..7a2a682d3416 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -21,8 +21,7 @@ from torch.onnx import export import onnx -from diffusers import OnnxStableDiffusionPipeline, StableDiffusionPipeline -from diffusers.onnx_utils import OnnxRuntimeModel +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline from packaging import version diff --git a/scripts/convert_unclip_txt2img_to_image_variation.py b/scripts/convert_unclip_txt2img_to_image_variation.py new file mode 100644 index 000000000000..d228a537ed4c --- /dev/null +++ b/scripts/convert_unclip_txt2img_to_image_variation.py @@ -0,0 +1,40 @@ +import argparse + +from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--txt2img_unclip", + default="kakaobrain/karlo-v1-alpha", + type=str, + required=False, + help="The pretrained txt2img unclip.", + ) + + args = parser.parse_args() + + txt2img = UnCLIPPipeline.from_pretrained(args.txt2img_unclip) + + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + img2img = UnCLIPImageVariationPipeline( + decoder=txt2img.decoder, + text_encoder=txt2img.text_encoder, + tokenizer=txt2img.tokenizer, + text_proj=txt2img.text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=txt2img.super_res_first, + super_res_last=txt2img.super_res_last, + decoder_scheduler=txt2img.decoder_scheduler, + super_res_scheduler=txt2img.super_res_scheduler, + ) + + img2img.save_pretrained(args.dump_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 10e7d560b147..46480270e80e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,7 +1,6 @@ __version__ = "0.12.0.dev0" from .configuration_utils import ConfigMixin -from .onnx_utils import OnnxRuntimeModel from .utils import ( OptionalDependencyNotAvailable, is_flax_available, @@ -18,15 +17,23 @@ ) +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_onnx_objects import * # noqa F403 +else: + from .pipelines import OnnxRuntimeModel + try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .modeling_utils import ModelMixin from .models import ( AutoencoderKL, + ModelMixin, PriorTransformer, Transformer2DModel, UNet1DModel, @@ -43,11 +50,13 @@ get_polynomial_decay_schedule_with_warmup, get_scheduler, ) - from .pipeline_utils import DiffusionPipeline from .pipelines import ( + AudioPipelineOutput, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, + DiffusionPipeline, + ImagePipelineOutput, KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, @@ -105,6 +114,7 @@ StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionUpscalePipeline, + UnCLIPImageVariationPipeline, UnCLIPPipeline, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, @@ -149,10 +159,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: - from .modeling_flax_utils import FlaxModelMixin + from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL - from .pipeline_flax_utils import FlaxDiffusionPipeline + from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py index 27bef08182f4..1c84012389a9 100644 --- a/src/diffusers/experimental/rl/value_guided_sampling.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -18,7 +18,7 @@ import tqdm from ...models.unet_1d import UNet1DModel -from ...pipeline_utils import DiffusionPipeline +from ...pipelines import DiffusionPipeline from ...utils.dummy_pt_objects import DDPMScheduler diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d0ee290fbe20..474b8412560e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,12 +16,15 @@ if is_torch_available(): - from .attention import Transformer2DModel + from .autoencoder_kl import AutoencoderKL + from .dual_transformer_2d import DualTransformer2DModel + from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer + from .transformer_2d import Transformer2DModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel - from .vae import AutoencoderKL, VQModel + from .vq_model import VQModel if is_flax_available(): from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9fe6a8034c22..acee0ff6b2fb 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,11 +20,11 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput from ..utils.import_utils import is_xformers_available from .cross_attention import CrossAttention +from .modeling_utils import ModelMixin @dataclass @@ -204,17 +204,17 @@ def forward( """ # 1. Input if self.is_input_continuous: - batch, channel, height, weight = hidden_states.shape + batch, channel, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -231,15 +231,11 @@ def forward( # 3. Output if self.is_input_continuous: if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual elif self.is_input_vectorized: @@ -707,7 +703,13 @@ def __init__( self.transformer_index_for_condition = [1, 0] def forward( - self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, ): """ Args: @@ -742,6 +744,7 @@ def forward( input_states, encoder_hidden_states=condition_state, timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] encoded_states.append(encoded_state - input_states) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py new file mode 100644 index 000000000000..1bf86627610d --- /dev/null +++ b/src/diffusers/models/autoencoder_kl.py @@ -0,0 +1,177 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py new file mode 100644 index 000000000000..654ce405df67 --- /dev/null +++ b/src/diffusers/models/dual_transformer_2d.py @@ -0,0 +1,151 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import nn + +from .transformer_2d import Transformer2DModel, Transformer2DModelOutput + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py similarity index 99% rename from src/diffusers/modeling_flax_pytorch_utils.py rename to src/diffusers/models/modeling_flax_pytorch_utils.py index 9c7a5de2ad6e..7463b408ed21 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/models/modeling_flax_pytorch_utils.py @@ -19,7 +19,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey -from .utils import logging +from ..utils import logging logger = logging.get_logger(__name__) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py similarity index 99% rename from src/diffusers/modeling_flax_utils.py rename to src/diffusers/models/modeling_flax_utils.py index 857fdd1b0b33..aeeeccad674b 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -27,9 +27,8 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from . import __version__, is_torch_available -from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax -from .utils import ( +from .. import __version__, is_torch_available +from ..utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, @@ -37,6 +36,7 @@ WEIGHTS_NAME, logging, ) +from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax logger = logging.get_logger(__name__) @@ -189,7 +189,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): ```""" return self._cast_floating_to(params, jnp.float16, mask) - def init_weights(self, rng: jax.random.PRNGKey) -> Dict: + def init_weights(self, rng: jax.random.KeyArray) -> Dict: raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/models/modeling_utils.py similarity index 97% rename from src/diffusers/modeling_utils.py rename to src/diffusers/models/modeling_utils.py index 6d934e6b3049..91c44973b34f 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os from functools import partial from typing import Callable, List, Optional, Tuple, Union @@ -25,11 +26,11 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from . import __version__ -from .hub_utils import HF_HUB_OFFLINE -from .utils import ( +from .. import __version__ +from ..utils import ( CONFIG_NAME, DIFFUSERS_CACHE, + HF_HUB_OFFLINE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, @@ -148,7 +149,7 @@ class ModelMixin(torch.nn.Module): and saving models. - **config_name** ([`str`]) -- A filename under which the model should be stored when calling - [`~modeling_utils.ModelMixin.save_pretrained`]. + [`~models.ModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @@ -230,7 +231,7 @@ def save_pretrained( ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. + `[`~models.ModelMixin.from_pretrained`]` class method. Arguments: save_directory (`str` or `os.PathLike`): @@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P state_dict = load_state_dict(model_file) # move the parms from meta device to cpu for param_name, param in state_dict.items(): - set_module_tensor_to_device(model, param_name, param_device, value=param) + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + if accepts_dtype: + set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype) + else: + set_module_tensor_to_device(model, param_name, param_device, value=param) else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by deafult the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map) + accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) loading_info = { "missing_keys": [], @@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) - dtype = set(v.dtype for v in state_dict.values()) - - if len(dtype) > 1 and torch.float32 not in dtype: - raise ValueError( - f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please" - f" make sure that {model_file} weights have only one dtype." - ) - elif len(dtype) > 1 and torch.float32 in dtype: - dtype = torch.float32 - else: - dtype = dtype.pop() - - # move model to correct dtype - model = model.to(dtype) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 714d5d52cf2d..998ca494a43d 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -6,10 +6,10 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .attention import BasicTransformerBlock from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin @dataclass diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py new file mode 100644 index 000000000000..2513b57ee28e --- /dev/null +++ b/src/diffusers/models/transformer_2d.py @@ -0,0 +1,244 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..models.embeddings import ImagePositionalEmbeddings +from ..utils import BaseOutput +from .attention import BasicTransformerBlock +from .modeling_utils import ModelMixin + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + cross_attention_kwargs=None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 00083fb392ff..2ab6c69da72b 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -19,9 +19,9 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index b49931a1368d..ccbb218b6c3b 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -18,9 +18,9 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d5ccb169e0b4..8099cd8421fb 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -19,10 +19,10 @@ import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..utils import BaseOutput, logging from .cross_attention import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 3a3f1d9e146d..8d8308c5bfb9 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -20,9 +20,9 @@ from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, @@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos: bool = True freq_shift: int = 0 - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index bcfc8789ab5c..f46cf7bdde10 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional import numpy as np import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -37,33 +35,6 @@ class DecoderOutput(BaseOutput): sample: torch.FloatTensor -@dataclass -class VQEncoderOutput(BaseOutput): - """ - Output of VQModel encoding method. - - Args: - latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Encoded output sample of the model. Output of the last layer of the model. - """ - - latents: torch.FloatTensor - - -@dataclass -class AutoencoderKLOutput(BaseOutput): - """ - Output of AutoencoderKL encoding method. - - Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. - `DiagonalGaussianDistribution` allows for sampling latents from the distribution. - """ - - latent_dist: "DiagonalGaussianDistribution" - - class Encoder(nn.Module): def __init__( self, @@ -384,255 +355,3 @@ def nll(self, sample, dims=[1, 2, 3]): def mode(self): return self.mean - - -class VQModel(ModelMixin, ConfigMixin): - r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray - Kavukcuoglu. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. - vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 3, - sample_size: int = 32, - num_vq_embeddings: int = 256, - norm_num_groups: int = 32, - vq_embed_dim: Optional[int] = None, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=False, - ) - - vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels - - self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1) - self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) - self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - ) - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: - h = self.encoder(x) - h = self.quant_conv(h) - - if not return_dict: - return (h,) - - return VQEncoderOutput(latents=h) - - def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - h = self.encode(x).latents - dec = self.decode(h).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - -class AutoencoderKL(ModelMixin, ConfigMixin): - r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma - and Max Welling. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - sample_size: int = 32, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.use_slicing = False - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - z = self.post_quant_conv(z) - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] - decoded = torch.cat(decoded_slices) - else: - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 7ecda9a6e9a0..4533bb5551f7 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -25,8 +25,8 @@ from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput +from .modeling_flax_utils import FlaxModelMixin @flax.struct.dataclass @@ -806,7 +806,7 @@ def setup(self): dtype=self.dtype, ) - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py new file mode 100644 index 000000000000..18fa80cb7b6b --- /dev/null +++ b/src/diffusers/models/vq_model.py @@ -0,0 +1,148 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: torch.FloatTensor + + +class VQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 23c4b29a53ae..e8df9e93eb6e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -1,6 +1,4 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,867 +10,10 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import inspect -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import torch - -import diffusers -import PIL -from huggingface_hub import model_info, snapshot_download -from packaging import version -from PIL import Image -from tqdm.auto import tqdm - -from .configuration_utils import ConfigMixin -from .dynamic_modules_utils import get_class_from_dynamic_module -from .hub_utils import HF_HUB_OFFLINE, http_user_agent -from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import ( - CONFIG_NAME, - DIFFUSERS_CACHE, - ONNX_WEIGHTS_NAME, - WEIGHTS_NAME, - BaseOutput, - deprecate, - is_accelerate_available, - is_safetensors_available, - is_torch_version, - is_transformers_available, - logging, -) - - -if is_transformers_available(): - import transformers - from transformers import PreTrainedModel - - -INDEX_FILE = "diffusion_pytorch_model.bin" -CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" -DUMMY_MODULES_FOLDER = "diffusers.utils" -TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" - - -logger = logging.get_logger(__name__) - - -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_pretrained", "from_pretrained"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - "ProcessorMixin": ["save_pretrained", "from_pretrained"], - "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], - }, - "onnxruntime.training": { - "ORTModule": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - -@dataclass -class ImagePipelineOutput(BaseOutput): - """ - Output class for image pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - -@dataclass -class AudioPipelineOutput(BaseOutput): - """ - Output class for audio pipelines. - - Args: - audios (`np.ndarray`) - List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the - denoised audio samples of the diffusion pipeline. - """ - - audios: np.ndarray - - -def is_safetensors_compatible(info) -> bool: - filenames = set(sibling.rfilename for sibling in info.siblings) - pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) - is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) - for pt_filename in pt_filenames: - prefix, raw = os.path.split(pt_filename) - if raw == "pytorch_model.bin": - # transformers specific - sf_filename = os.path.join(prefix, "model.safetensors") - else: - sf_filename = pt_filename[: -len(".bin")] + ".safetensors" - if is_safetensors_compatible and sf_filename not in filenames: - logger.warning(f"{sf_filename} not found") - is_safetensors_compatible = False - return is_safetensors_compatible - - -class DiffusionPipeline(ConfigMixin): - r""" - Base class for all models. - - [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines - and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: - - - move all PyTorch modules to the device of your choice - - enabling/disabling the progress bar for the denoising iteration - - Class attributes: - - - **config_name** (`str`) -- name of the config file that will store the class and module names of all - components of the diffusion pipeline. - - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be - passed for the pipeline to function (should be overridden by subclasses). - """ - config_name = "model_index.json" - _optional_components = [] - - def register_modules(self, **kwargs): - # import it here to avoid circular import - from diffusers import pipelines - - for name, module in kwargs.items(): - # retrieve library - if module is None: - register_dict = {name: (None, None)} - else: - library = module.__module__.split(".")[0] - - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir - - # retrieve class_name - class_name = module.__class__.__name__ - - register_dict = {name: (library, class_name)} - - # save model index config - self.register_to_config(**register_dict) - - # set models - setattr(self, name, module) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - safe_serialization: bool = False, - ): - """ - Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to - a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading - method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - """ - self.save_config(save_directory) - - model_index_dict = dict(self.config) - model_index_dict.pop("_class_name") - model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module", None) - - expected_modules, optional_kwargs = self._get_signature_keys(self) - - def is_saveable_module(name, value): - if name not in expected_modules: - return False - if name in self._optional_components and value[0] is None: - return False - return True - - model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} - - for pipeline_component_name in model_index_dict.keys(): - sub_model = getattr(self, pipeline_component_name) - model_cls = sub_model.__class__ - - save_method_name = None - # search for the model's base class in LOADABLE_CLASSES - for library_name, library_classes in LOADABLE_CLASSES.items(): - library = importlib.import_module(library_name) - for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class, None) - if class_candidate is not None and issubclass(model_cls, class_candidate): - # if we found a suitable base class in LOADABLE_CLASSES then grab its save method - save_method_name = save_load_methods[0] - break - if save_method_name is not None: - break - - save_method = getattr(sub_model, save_method_name) - - # Call the save method with the argument safe_serialization only if it's supported - save_method_signature = inspect.signature(save_method) - save_method_accept_safe = "safe_serialization" in save_method_signature.parameters - if save_method_accept_safe: - save_method( - os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization - ) - else: - save_method(os.path.join(save_directory, pipeline_component_name)) - - def to(self, torch_device: Optional[Union[str, torch.device]] = None): - if torch_device is None: - return self - - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - if module.dtype == torch.float16 and str(torch_device) in ["cpu"]: - logger.warning( - "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - module.to(torch_device) - return self - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - return module.device - return torch.device("cpu") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. - - The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on - https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like - `CompVis/ldm-text2im-large-256`. - - A path to a *directory* containing pipeline weights saved using - [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - custom_pipeline (`str`, *optional*): - - - - This is an experimental feature and is likely to change in the future. - - - - Can be either: - - - A string, the *repo id* of a custom pipeline hosted inside a model repo on - https://huggingface.co/. Valid repo ids have to be located under a user or organization name, - like `hf-internal-testing/diffusers-dummy-pipeline`. - - - - It is required that the model repo has a file, called `pipeline.py` that defines the custom - pipeline. - - - - - A string, the *file name* of a community pipeline hosted on GitHub under - https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to - match exactly the file name without `.py` located under the above link, *e.g.* - `clip_guided_stable_diffusion`. - - - - Community pipelines are always loaded from the current `main` branch of GitHub. - - - - - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. - - - - It is required that the directory has a file, called `pipeline.py` that defines the custom - pipeline. - - - - For more information on how to load and create custom pipelines, please have a look at [Loading and - Adding Custom - Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) - - torch_dtype (`str` or `torch.dtype`, *optional*): - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub): - The specific model version to use. It can be a branch name, a tag name, or a commit id similar to - `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a - custom pipeline from GitHub. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. specify the folder name here. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - return_cached_folder (`bool`, *optional*, defaults to `False`): - If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the - specific pipeline class. The overwritten components are then directly passed to the pipelines - `__init__` method. See example below for more information. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - - Examples: - - ```py - >>> from diffusers import DiffusionPipeline - - >>> # Download pipeline from huggingface.co and cache. - >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - >>> # Download pipeline that requires an authorization token - >>> # For more information on access tokens, please refer to this section - >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - - >>> # Use a different scheduler - >>> from diffusers import LMSDiscreteScheduler - - >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.scheduler = scheduler - ``` - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - custom_pipeline = kwargs.pop("custom_pipeline", None) - custom_revision = kwargs.pop("custom_revision", None) - provider = kwargs.pop("provider", None) - sess_options = kwargs.pop("sess_options", None) - device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - return_cached_folder = kwargs.pop("return_cached_folder", False) - - # 1. Download the checkpoints and configs - # use snapshot download here to get it working from from_pretrained - if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.load_config( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - # make sure we only download sub-folders and `diffusers` filenames - folder_names = [k for k in config_dict.keys() if not k.startswith("_")] - allow_patterns = [os.path.join(k, "*") for k in folder_names] - allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] - - # make sure we don't download flax weights - ignore_patterns = ["*.msgpack"] - - if custom_pipeline is not None: - allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] - - if cls != DiffusionPipeline: - requested_pipeline_class = cls.__name__ - else: - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - user_agent = {"pipeline_class": requested_pipeline_class} - if custom_pipeline is not None and not custom_pipeline.endswith(".py"): - user_agent["custom_pipeline"] = custom_pipeline - - user_agent = http_user_agent(user_agent) - - if is_safetensors_available(): - info = model_info( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - revision=revision, - ) - if is_safetensors_compatible(info): - ignore_patterns.append("*.bin") - - # download all allow_patterns - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - else: - cached_folder = pretrained_model_name_or_path - - config_dict = cls.load_config(cached_folder) - - # 2. Load the pipeline class, if using custom module then load it from the hub - # if we load from explicit class, let's use it - if custom_pipeline is not None: - if custom_pipeline.endswith(".py"): - path = Path(custom_pipeline) - # decompose into folder & file - file_name = path.name - custom_pipeline = path.parent.absolute() - else: - file_name = CUSTOM_PIPELINE_FILE_NAME - - pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision - ) - elif cls != DiffusionPipeline: - pipeline_class = cls - else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - - # To be removed in 1.0.0 - if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( - version.parse(config_dict["_diffusers_version"]).base_version - ) <= version.parse("0.5.1"): - from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy - - pipeline_class = StableDiffusionInpaintPipelineLegacy - - deprecation_message = ( - "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" - f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" - " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" - " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" - f" checkpoint {pretrained_model_name_or_path} to the format of" - " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" - " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." - ) - deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) - - # some modules can be passed directly to the init - # in this case they are already instantiated in `kwargs` - # extract them here - expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - - init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - - # define init kwargs - init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} - init_kwargs = {**init_kwargs, **passed_pipe_kwargs} - - # remove `null` components - def load_module(name, value): - if value[0] is None: - return False - if name in passed_class_obj and passed_class_obj[name] is None: - return False - return True - - init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - - if len(unused_kwargs) > 0: - logger.warning( - f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." - ) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." - ) - - # import it here to avoid circular import - from diffusers import pipelines - - # 3. Load each module in the pipeline - for name, (library_name, class_name) in init_dict.items(): - # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names - if class_name.startswith("Flax"): - class_name = class_name[4:] - - is_pipeline_module = hasattr(pipelines, library_name) - loaded_sub_model = None - - # if the model is in a pipeline module, then we load it from the pipeline - if name in passed_class_obj: - # 1. check that passed_class_obj has correct parent class - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" - ) - else: - logger.warning( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - # set passed class object - loaded_sub_model = passed_class_obj[name] - elif is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. - library = importlib.import_module(library_name) - - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - if loaded_sub_model is None: - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - if load_method_name is None: - none_module = class_obj.__module__ - is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( - TRANSFORMERS_DUMMY_MODULES_FOLDER - ) - if is_dummy_path and "dummy" in none_module: - # call class_obj for nice error message of missing requirements - class_obj() - - raise ValueError( - f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" - f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." - ) - - load_method = getattr(class_obj, load_method_name) - loading_kwargs = {} - - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - loading_kwargs["sess_options"] = sess_options - - is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) - is_transformers_model = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedModel) - and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") - ) - - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. - # This makes sure that the weights won't be initialized which significantly speeds up loading. - if is_diffusers_model or is_transformers_model: - loading_kwargs["device_map"] = device_map - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) - - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - # 4. Potentially add passed objects if expected - missing_modules = set(expected_modules) - set(init_kwargs.keys()) - passed_modules = list(passed_class_obj.keys()) - optional_modules = pipeline_class._optional_components - if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): - for module in missing_modules: - init_kwargs[module] = passed_class_obj.get(module, None) - elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs - raise ValueError( - f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." - ) - - # 5. Instantiate the pipeline - model = pipeline_class(**init_kwargs) - - if return_cached_folder: - return model, cached_folder - return model - - @staticmethod - def _get_signature_keys(obj): - parameters = inspect.signature(obj.__init__).parameters - required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} - optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - set(["self"]) - return expected_modules, optional_parameters - - @property - def components(self) -> Dict[str, Any]: - r""" - - The `self.components` property can be useful to run different pipelines with the same weights and - configurations to not have to re-allocate memory. - - Examples: - - ```py - >>> from diffusers import ( - ... StableDiffusionPipeline, - ... StableDiffusionImg2ImgPipeline, - ... StableDiffusionInpaintPipeline, - ... ) - - >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) - >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) - ``` - - Returns: - A dictionaly containing all the modules needed to initialize the pipeline. - """ - expected_modules, optional_parameters = self._get_signature_keys(self) - components = { - k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters - } - - if set(components.keys()) != expected_modules: - raise ValueError( - f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected_modules} to be defined, but {components} are defined." - ) - - return components - - @staticmethod - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.set_use_memory_efficient_attention_xformers(True) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.set_use_memory_efficient_attention_xformers(False) - - def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - self.set_attention_slice(slice_size) +# limitations under the License. - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works - def set_attention_slice(self, slice_size: Optional[int]): - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) +from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7cecfe569234..dcf77c1a43f9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -20,6 +20,7 @@ from .ddpm import DDPMPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline @@ -53,7 +54,7 @@ StableDiffusionUpscalePipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .unclip import UnCLIPPipeline + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, @@ -62,6 +63,14 @@ ) from .vq_diffusion import VQDiffusionPipeline +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 +else: + from .onnx_utils import OnnxRuntimeModel + try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() @@ -84,6 +93,14 @@ else: from .stable_diffusion import StableDiffusionKDiffusionPipeline +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 +else: + from .pipeline_flax_utils import FlaxDiffusionPipeline + try: if not (is_flax_available() and is_transformers_available()): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 932c4afc5a9c..ab849c2f3991 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -23,7 +23,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -33,6 +32,7 @@ PNDMScheduler, ) from ...utils import deprecate, logging +from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 6d712cac9eea..c5579e4df456 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -35,6 +34,7 @@ PNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation diff --git a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py index 1d7722125d9c..120c16900751 100644 --- a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py +++ b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py @@ -22,8 +22,8 @@ from PIL import Image from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, DDPMScheduler +from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput from .mel import Mel diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index 903b50cf336c..4bf93417b535 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -17,8 +17,8 @@ import torch -from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ...utils import logging +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -63,12 +63,11 @@ def __call__( The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.AudioPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ if audio_length_in_s is None: diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index a3d4b589e700..5489abf393a6 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -16,8 +16,8 @@ import torch -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...utils import deprecate +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDIMPipeline(DiffusionPipeline): @@ -66,12 +66,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ if ( diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 123b4f844c2a..f10e3aa9c482 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -18,8 +18,8 @@ import torch from ...configuration_utils import FrozenDict -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...utils import deprecate +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDPMPipeline(DiffusionPipeline): @@ -62,12 +62,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ message = ( "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index ec0c71af4f0f..ca3408d83fcd 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -19,16 +19,14 @@ import torch.nn as nn import torch.utils.checkpoint +from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import BaseModelOutput -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer from transformers.utils import logging from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class LDMTextToImagePipeline(DiffusionPipeline): @@ -105,12 +103,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index a97d18c9fee6..18b8c4988015 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -8,7 +8,6 @@ import PIL from ...models import UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -18,6 +17,7 @@ PNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput def preprocess(image): @@ -95,12 +95,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ message = "Please use `image` instead of `init_image`." init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 19380d36246a..d8717023b42c 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -18,8 +18,8 @@ import torch from ...models import UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class LDMPipeline(DiffusionPipeline): @@ -64,12 +64,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ latents = torch.randn( diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py similarity index 98% rename from src/diffusers/onnx_utils.py rename to src/diffusers/pipelines/onnx_utils.py index b2c533ed741f..9308a1878845 100644 --- a/src/diffusers/onnx_utils.py +++ b/src/diffusers/pipelines/onnx_utils.py @@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download -from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging +from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging if is_onnx_available(): diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index f9de7d911068..b97e46b34353 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -23,9 +23,9 @@ from transformers import CLIPFeatureExtractor from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging +from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .image_encoder import PaintByExampleImageEncoder diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py similarity index 91% rename from src/diffusers/pipeline_flax_utils.py rename to src/diffusers/pipelines/pipeline_flax_utils.py index f8fd304776d7..1922d0fad6e5 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -17,7 +17,7 @@ import importlib import inspect import os -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np @@ -28,11 +28,10 @@ from PIL import Image from tqdm.auto import tqdm -from .configuration_utils import ConfigMixin -from .hub_utils import http_user_agent -from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin -from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging +from ..configuration_utils import ConfigMixin +from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin +from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging if is_transformers_available(): @@ -475,6 +474,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = pipeline_class(**init_kwargs, dtype=dtype) return model, params + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - set(["self"]) + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... FlaxStableDiffusionPipeline, + ... FlaxStableDiffusionImg2ImgPipeline, + ... ) + + >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16 + ... ) + >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) + ``` + + Returns: + A dictionary containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + @staticmethod def numpy_to_pil(images): """ diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 000000000000..854a003e8967 --- /dev/null +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,881 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch + +import diffusers +import PIL +from huggingface_hub import model_info, snapshot_download +from packaging import version +from PIL import Image +from tqdm.auto import tqdm + +from ..configuration_utils import ConfigMixin +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from ..utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + ONNX_WEIGHTS_NAME, + WEIGHTS_NAME, + BaseOutput, + deprecate, + get_class_from_dynamic_module, + http_user_agent, + is_accelerate_available, + is_safetensors_available, + is_torch_version, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + import transformers + from transformers import PreTrainedModel + + +INDEX_FILE = "diffusion_pytorch_model.bin" +CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" +DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "ModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + "ProcessorMixin": ["save_pretrained", "from_pretrained"], + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], + }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class AudioPipelineOutput(BaseOutput): + """ + Output class for audio pipelines. + + Args: + audios (`np.ndarray`) + List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the + denoised audio samples of the diffusion pipeline. + """ + + audios: np.ndarray + + +def is_safetensors_compatible(info) -> bool: + filenames = set(sibling.rfilename for sibling in info.siblings) + pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) + is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) + for pt_filename in pt_filenames: + prefix, raw = os.path.split(pt_filename) + if raw == "pytorch_model.bin": + # transformers specific + sf_filename = os.path.join(prefix, "model.safetensors") + else: + sf_filename = pt_filename[: -len(".bin")] + ".safetensors" + if is_safetensors_compatible and sf_filename not in filenames: + logger.warning(f"{sf_filename} not found") + is_safetensors_compatible = False + return is_safetensors_compatible + + +class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** (`str`) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be + passed for the pipeline to function (should be overridden by subclasses). + """ + config_name = "model_index.json" + _optional_components = [] + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + if module is None: + register_dict = {name: (None, None)} + else: + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + ): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + save_method = getattr(sub_model, save_method_name) + + # Call the save method with the argument safe_serialization only if it's supported + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + if save_method_accept_safe: + save_method( + os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None): + if torch_device is None: + return self + + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + if module.dtype == torch.float16 and str(torch_device) in ["cpu"]: + logger.warning( + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + module.to(torch_device) + return self + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for name in module_names.keys(): + module = getattr(self, name) + if isinstance(module, torch.nn.Module): + return module.device + return torch.device("cpu") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + custom_pipeline (`str`, *optional*): + + + + This is an experimental feature and is likely to change in the future. + + + + Can be either: + + - A string, the *repo id* of a custom pipeline hosted inside a model repo on + https://huggingface.co/. Valid repo ids have to be located under a user or organization name, + like `hf-internal-testing/diffusers-dummy-pipeline`. + + + + It is required that the model repo has a file, called `pipeline.py` that defines the custom + pipeline. + + + + - A string, the *file name* of a community pipeline hosted on GitHub under + https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to + match exactly the file name without `.py` located under the above link, *e.g.* + `clip_guided_stable_diffusion`. + + + + Community pipelines are always loaded from the current `main` branch of GitHub. + + + + - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. + + + + It is required that the directory has a file, called `pipeline.py` that defines the custom + pipeline. + + + + For more information on how to load and create custom pipelines, please have a look at [Loading and + Adding Custom + Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + + torch_dtype (`str` or `torch.dtype`, *optional*): + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub): + The specific model version to use. It can be a branch name, a tag name, or a commit id similar to + `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a + custom pipeline from GitHub. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + return_cached_folder (`bool`, *optional*, defaults to `False`): + If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overwritten components are then directly passed to the pipelines + `__init__` method. See example below for more information. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + + >>> # Use a different scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + return_cached_folder = kwargs.pop("return_cached_folder", False) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] + + # make sure we don't download flax weights + ignore_patterns = ["*.msgpack"] + + if custom_pipeline is not None: + allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] + + if cls != DiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + user_agent = {"pipeline_class": requested_pipeline_class} + if custom_pipeline is not None and not custom_pipeline.endswith(".py"): + user_agent["custom_pipeline"] = custom_pipeline + + user_agent = http_user_agent(user_agent) + + if is_safetensors_available(): + info = model_info( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + ) + if is_safetensors_compatible(info): + ignore_patterns.append("*.bin") + else: + ignore_patterns.append("*.safetensors") + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.load_config(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + + pipeline_class = get_class_from_dynamic_module( + custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision + ) + elif cls != DiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # To be removed in 1.0.0 + if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( + version.parse(config_dict["_diffusers_version"]).base_version + ) <= version.parse("0.5.1"): + from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy + + pipeline_class = StableDiffusionInpaintPipelineLegacy + + deprecation_message = ( + "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" + f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" + " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" + " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" + f" checkpoint {pretrained_model_name_or_path} to the format of" + " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" + " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." + ) + deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + + if len(unused_kwargs) > 0: + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warning( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if class_candidate is not None and issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + if load_method_name is None: + none_module = class_obj.__module__ + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: + # call class_obj for nice error message of missing requirements + class_obj() + + raise ValueError( + f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" + f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." + ) + + load_method = getattr(class_obj, load_method_name) + loading_kwargs = {} + + if issubclass(class_obj, torch.nn.Module): + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + + is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_transformers_model = ( + is_transformers_available() + and issubclass(class_obj, PreTrainedModel) + and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0") + ) + + # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. + # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. + # This makes sure that the weights won't be initialized which significantly speeds up loading. + if is_diffusers_model or is_transformers_model: + loading_kwargs["device_map"] = device_map + loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." + ) + + # 5. Instantiate the pipeline + model = pipeline_class(**init_kwargs) + + if return_cached_folder: + return model, cached_folder + return model + + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - set(["self"]) + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... StableDiffusionPipeline, + ... StableDiffusionImg2ImgPipeline, + ... StableDiffusionInpaintPipeline, + ... ) + + >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) + >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) + ``` + + Returns: + A dictionary containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.set_use_memory_efficient_attention_xformers(False) + + def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + self.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def set_attention_slice(self, slice_size: Optional[int]): + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 020dd7c4ee79..34204b124bf6 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -18,8 +18,8 @@ import torch from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import PNDMScheduler +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class PNDMPipeline(DiffusionPipeline): @@ -62,12 +62,11 @@ def __call__( output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a - [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index a316e5bc8182..32374e9a310b 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -21,9 +21,9 @@ import PIL from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import RePaintScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -118,12 +118,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ message = "Please use `image` instead of `original_image`." diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 8dc7001a4a69..a53d0840b137 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -17,8 +17,8 @@ import torch from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import ScoreSdeVeScheduler +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class ScoreSdeVePipeline(DiffusionPipeline): @@ -57,12 +57,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ img_size = self.unet.config.sample_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 3e8917bd3219..4392d9d8058e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -25,9 +25,9 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 09760cef2d60..d5add17107b9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -28,7 +28,6 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel -from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, @@ -36,6 +35,7 @@ FlaxPNDMScheduler, ) from ...utils import deprecate, logging +from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -184,18 +184,14 @@ def _generate( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, - num_inference_steps: int = 50, - height: Optional[int] = None, - width: Optional[int] = None, - guidance_scale: float = 7.5, + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, latents: Optional[jnp.array] = None, - neg_prompt_ids: jnp.array = None, + neg_prompt_ids: Optional[jnp.array] = None, ): - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -281,15 +277,15 @@ def __call__( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, + prng_seed: jax.random.KeyArray, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, return_dict: bool = True, jit: bool = False, - neg_prompt_ids: jnp.array = None, ): r""" Function invoked when calling the pipeline for generation. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 7b0b35f89e00..16ac2ba155d2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -14,7 +14,7 @@ import warnings from functools import partial -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -27,7 +27,6 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel -from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, @@ -35,12 +34,16 @@ FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, logging +from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" @@ -106,6 +109,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): if not isinstance(prompt, (str, list)): @@ -116,10 +120,8 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image if isinstance(image, Image.Image): image = [image] - processed_image = [] - for img in image: - processed_image.append(preprocess(img, self.dtype)) - processed_image = jnp.array(processed_image).squeeze() + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) text_input = self.tokenizer( prompt, @@ -128,7 +130,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image truncation=True, return_tensors="np", ) - return text_input.input_ids, processed_image + return text_input.input_ids, processed_images def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) @@ -164,12 +166,11 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): return images, has_nsfw_concepts - def get_timestep_start(self, num_inference_steps, strength, scheduler_state): + def get_timestep_start(self, num_inference_steps, strength): # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - t_start = max(num_inference_steps - init_timestep + offset, 0) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) return t_start @@ -178,13 +179,14 @@ def _generate( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, - strength: float = 0.8, - num_inference_steps: int = 50, - height: int = 512, - width: int = 512, - guidance_scale: float = 7.5, - debug: bool = False, + prng_seed: jax.random.KeyArray, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -197,18 +199,32 @@ def _generate( batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + # Create init_latents init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) init_latents = 0.18215 * init_latents - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) def loop_body(step, args): latents, scheduler_state = args @@ -241,19 +257,19 @@ def loop_body(step, args): params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) - t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) - latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size) - init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) - latents = init_latents + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) - if debug: + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: # run with python for loop - for i in range(t_start, len(scheduler_state.timesteps)): + for i in range(start_timestep, num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop( - t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state) - ) + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents @@ -268,14 +284,15 @@ def __call__( image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - height: int = 512, - width: int = 512, - guidance_scale: float = 7.5, strength: float = 0.8, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.array] = 7.5, + noise: jnp.array = None, + neg_prompt_ids: jnp.array = None, return_dict: bool = True, jit: bool = False, - debug: bool = False, ): r""" Function invoked when calling the pipeline for generation. @@ -287,12 +304,17 @@ def __call__( Array representing an image batch, that will be used as the starting point for the process. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). @@ -300,18 +322,17 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in + noise (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop + Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a @@ -319,76 +340,109 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + start_timestep = self.get_timestep_start(num_inference_steps, strength) + if jit: - image = _p_generate( + images = _p_generate( self, prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ) else: - image = self._generate( + images = self._generate( prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ) if self.safety_checker is not None: safety_params = params["safety_checker"] - image_uint8_casted = (image * 255).round().astype("uint8") - num_devices, batch_size = image.shape[:2] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] - image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit) - image = np.asarray(image) + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: - image[i] = np.asarray(image_uint8_casted[i]) + images[i] = np.asarray(images_uint8_casted[i]) - image = image.reshape(num_devices, batch_size, height, width, 3) + images = images.reshape(num_devices, batch_size, height, width, 3) else: + images = np.asarray(images) has_nsfw_concept = False if not return_dict: - return (image, has_nsfw_concept) + return (images, has_nsfw_concept) - return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10)) +# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6, 7, 8), +) def _p_generate( pipe, prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ): return pipe._generate( - prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 1b9a8ff724a4..d7ae6aa8865e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -21,10 +21,10 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 71b5fdbbebce..6e216a675f1a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -22,10 +22,10 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 930d61de99cc..f7bb440a534c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -22,10 +22,10 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index d32586febbbd..98914eaa25c7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -8,10 +8,10 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer from ...configuration_utils import FrozenDict -from ...onnx_utils import OnnxRuntimeModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging +from ..onnx_utils import OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 37a041bd33f0..b9b87e8f1425 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -23,7 +23,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -33,6 +32,7 @@ PNDMScheduler, ) from ...utils import deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index dfd3d2339384..e86f2d7226ba 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -26,7 +26,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -36,6 +35,7 @@ PNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 975afddfe003..fd2d4afb4bde 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -24,7 +24,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -34,6 +33,7 @@ PNDMScheduler, ) from ...utils import deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 857ad6f3d507..0213a1a563de 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -35,6 +34,7 @@ PNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 985b138dc0ab..77e1df6e1d11 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -25,9 +25,9 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 429e2da286ed..40e026e2166b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -35,6 +34,7 @@ PNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 1bb0cb051ad8..c39152721fe0 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -19,7 +19,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser -from ... import DiffusionPipeline +from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, logging from . import StableDiffusionPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 528ed0dca3c0..236ecea59fa4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -23,9 +23,9 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index e1f669d22b76..71b7306134a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -87,7 +87,7 @@ def __init__( module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor clip_input = jax.random.normal(rng, input_shape) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index d5b163ba46ef..c7b58bbfb5fd 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -10,7 +10,6 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, @@ -20,6 +19,7 @@ PNDMScheduler, ) from ...utils import deprecate, is_accelerate_available, logging +from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionSafePipelineOutput from .safety_checker import SafeStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 8da1faf9c63d..90f868371bb6 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -17,8 +17,8 @@ import torch from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import KarrasVeScheduler +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class KarrasVePipeline(DiffusionPipeline): @@ -68,12 +68,11 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ img_size = self.unet.config.sample_size diff --git a/src/diffusers/pipelines/unclip/__init__.py b/src/diffusers/pipelines/unclip/__init__.py index c495367bc770..23b54a7d2f79 100644 --- a/src/diffusers/pipelines/unclip/__init__.py +++ b/src/diffusers/pipelines/unclip/__init__.py @@ -13,4 +13,5 @@ from ...utils.dummy_torch_and_transformers_objects import UnCLIPPipeline else: from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline from .text_proj import UnCLIPTextProjModel diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 5dc8ed3a89e9..0f35d004a09a 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -13,16 +13,17 @@ # limitations under the License. import inspect -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch from torch.nn import functional as F -from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel -from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from diffusers.schedulers import UnCLIPScheduler from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import UnCLIPScheduler from ...utils import is_accelerate_available, logging from .text_proj import UnCLIPTextProjModel @@ -45,6 +46,8 @@ class UnCLIPPipeline(DiffusionPipeline): [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): @@ -115,31 +118,44 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): latents = latents * scheduler.init_noise_sigma return latents - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_encoder_output = self.text_encoder(text_input_ids.to(device)) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) - text_embeddings = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) @@ -148,11 +164,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr if do_classifier_free_guidance: uncond_tokens = [""] * batch_size - max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", - max_length=max_length, + max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) @@ -233,7 +248,7 @@ def _execution_device(self): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, prior_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25, @@ -242,6 +257,8 @@ def __call__( prior_latents: Optional[torch.FloatTensor] = None, decoder_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, prior_guidance_scale: float = 4.0, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", @@ -252,7 +269,8 @@ def __call__( Args: prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. prior_num_inference_steps (`int`, *optional*, defaults to 25): @@ -285,18 +303,29 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = text_model_output[0].shape[0] + device = self._execution_device batch_size = batch_size * num_images_per_prompt @@ -304,7 +333,7 @@ def __call__( do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) # prior @@ -313,6 +342,7 @@ def __call__( prior_timesteps_tensor = self.prior_scheduler.timesteps embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( (batch_size, embedding_dim), text_embeddings.dtype, @@ -376,6 +406,7 @@ def __call__( num_channels_latents = self.decoder.in_channels height = self.decoder.sample_size width = self.decoder.sample_size + decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, @@ -428,6 +459,7 @@ def __call__( channels = self.super_res_first.in_channels // 2 height = self.super_res_first.sample_size width = self.super_res_first.sample_size + super_res_latents = self.prepare_latents( (batch_size, channels, height, width), image_small.dtype, diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100644 index 000000000000..0b83407d8ccd --- /dev/null +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,457 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import torch +from torch.nn import functional as F + +import PIL +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import UnCLIPScheduler +from ...utils import is_accelerate_available, logging +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DiffusionPipeline): + """ + Pipeline to generate variations from an input image using unCLIP + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): + dtype = next(self.image_encoder.parameters()).dtype + + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can the be left to `None`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + else: + batch_size = image_embeddings.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) + + # decoder + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py index 31e7e0644e83..4482010519a2 100644 --- a/src/diffusers/pipelines/unclip/text_proj.py +++ b/src/diffusers/pipelines/unclip/text_proj.py @@ -15,9 +15,8 @@ import torch from torch import nn -from diffusers.modeling_utils import ModelMixin - from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin class UnCLIPTextProjModel(ModelMixin, ConfigMixin): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3d3f210c4183..40c6225c96a0 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -5,7 +5,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...modeling_utils import ModelMixin +from ...models import ModelMixin from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor from ...models.embeddings import TimestepEmbedding, Timesteps diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index d0202922da04..88e7e4b6a49f 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -7,9 +7,9 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging +from ..pipeline_utils import DiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index b7c22494df0f..74902665ead9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -27,11 +27,10 @@ CLIPVisionModelWithProjection, ) -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention import DualTransformer2DModel, Transformer2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import is_accelerate_available, logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 122fd7a9f72f..93c70688aec3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -23,9 +23,9 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import is_accelerate_available, logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index e12eabded73a..e05cb036a8ea 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -20,11 +20,10 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.attention import Transformer2DModel -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import is_accelerate_available, logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index a536f91d9ec8..bd63eda030cc 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -16,14 +16,13 @@ import torch -from diffusers import Transformer2DModel, VQModel -from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler from transformers import CLIPTextModel, CLIPTokenizer from ...configuration_utils import ConfigMixin, register_to_config -from ...modeling_utils import ModelMixin -from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...models import ModelMixin, Transformer2DModel, VQModel +from ...schedulers import VQDiffusionScheduler from ...utils import logging +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -212,7 +211,7 @@ def __call__( The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. @@ -221,9 +220,8 @@ def __call__( called at every step. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the - generated images. + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` + is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ if isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 70cf22654873..c332bcf54b9e 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -201,7 +201,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps + step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 86edcb441fcb..7c300d4a42c1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -184,11 +184,18 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + self.num_inference_steps = num_inference_steps - timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index e5a4d323e3eb..332a9917010c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -18,7 +18,22 @@ from packaging import version from .. import __version__ +from .constants import ( + _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CONFIG_NAME, + DIFFUSERS_CACHE, + DIFFUSERS_DYNAMIC_MODULE_NAME, + FLAX_WEIGHTS_NAME, + HF_MODULES_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + ONNX_EXTERNAL_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, +) from .deprecation_utils import deprecate +from .dynamic_modules_utils import get_class_from_dynamic_module +from .hub_utils import HF_HUB_OFFLINE, http_user_agent from .import_utils import ( ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, @@ -67,36 +82,6 @@ logger = get_logger(__name__) -hf_cache_home = os.path.expanduser( - os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) -) -default_cache_path = os.path.join(hf_cache_home, "diffusers") - - -CONFIG_NAME = "config.json" -WEIGHTS_NAME = "diffusion_pytorch_model.bin" -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" -ONNX_WEIGHTS_NAME = "model.onnx" -SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" -ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" -HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" -DIFFUSERS_CACHE = default_cache_path -DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" -HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) - -_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", -] - - def check_min_version(min_version): if version.parse(__version__) < version.parse(min_version): if "dev" in min_version: diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py new file mode 100644 index 000000000000..eaa8212298a0 --- /dev/null +++ b/src/diffusers/utils/constants.py @@ -0,0 +1,44 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + + +hf_cache_home = os.path.expanduser( + os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +default_cache_path = os.path.join(hf_cache_home, "diffusers") + + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +ONNX_WEIGHTS_NAME = "model.onnx" +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" +HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" +DIFFUSERS_CACHE = default_cache_path +DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" +HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) + +_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", +] diff --git a/src/diffusers/utils/dummy_onnx_objects.py b/src/diffusers/utils/dummy_onnx_objects.py new file mode 100644 index 000000000000..963906b24c36 --- /dev/null +++ b/src/diffusers/utils/dummy_onnx_objects.py @@ -0,0 +1,19 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class OnnxRuntimeModel(metaclass=DummyObject): + _backends = ["onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["onnx"]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 615f84d115bf..63a7d258a902 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -4,7 +4,7 @@ from ..utils import DummyObject, requires_backends -class ModelMixin(metaclass=DummyObject): +class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -19,7 +19,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class AutoencoderKL(metaclass=DummyObject): +class ModelMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -152,7 +152,7 @@ def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) -class DiffusionPipeline(metaclass=DummyObject): +class AudioPipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -212,6 +212,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ImagePipelineOutput(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class KarrasVePipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ba2798c784ef..25f347be9021 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -199,6 +199,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UnCLIPImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class UnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py similarity index 99% rename from src/diffusers/dynamic_modules_utils.py rename to src/diffusers/utils/dynamic_modules_utils.py index 693d9811fc83..464257bd7b35 100644 --- a/src/diffusers/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -28,8 +28,8 @@ from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info -from . import __version__ -from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging +from .. import __version__ +from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging COMMUNITY_PIPELINES_URL = ( @@ -172,7 +172,7 @@ def find_pipeline_class(loaded_module): Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class inheriting from `DiffusionPipeline`. """ - from .pipeline_utils import DiffusionPipeline + from ..pipelines import DiffusionPipeline cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) diff --git a/src/diffusers/hub_utils.py b/src/diffusers/utils/hub_utils.py similarity index 96% rename from src/diffusers/hub_utils.py rename to src/diffusers/utils/hub_utils.py index 33e176356667..e22b3644fb84 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -22,9 +22,10 @@ from huggingface_hub import HfFolder, whoami -from . import __version__ -from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -from .utils.import_utils import ( +from .. import __version__ +from .constants import HUGGINGFACE_CO_RESOLVE_ENDPOINT +from .import_utils import ( + ENV_VARS_TRUE_VALUES, _flax_version, _jax_version, _onnxruntime_version, @@ -34,13 +35,14 @@ is_onnx_available, is_torch_available, ) +from .logging import get_logger if is_modelcards_available(): from modelcards import CardData, ModelCard -logger = logging.get_logger(__name__) +logger = get_logger(__name__) MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index e7429d0a1945..0667edcfc62a 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -18,7 +18,7 @@ import torch -from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers import DiffusionPipeline, ImagePipelineOutput class CustomLocalPipeline(DiffusionPipeline): @@ -63,10 +63,10 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: - [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 2948151e3d00..75481ecbbb8c 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -19,7 +19,7 @@ import torch from diffusers import AutoencoderKL -from diffusers.modeling_utils import ModelMixin +from diffusers.models import ModelMixin from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device from parameterized import parameterized diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 88157f22de6b..756442600ef1 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -228,7 +228,6 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): model_id, safety_checker=None, scheduler=pndm, - device_map="auto", torch_dtype=torch.float16, ) pipe.to(torch_device) @@ -244,7 +243,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): image=init_image, mask_image=mask_image, generator=generator, - num_inference_steps=5, + num_inference_steps=2, output_type="np", ) diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index c1f67e557fd9..670082c20c24 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -248,6 +248,120 @@ def test_unclip(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_unclip_passed_text_embed(self): + device = torch.device("cpu") + + class DummyScheduler: + init_noise_sigma = 1 + + prior = self.dummy_prior + decoder = self.dummy_decoder + text_proj = self.dummy_text_proj + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + super_res_first = self.dummy_super_res_first + super_res_last = self.dummy_super_res_last + + prior_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample_range=5.0, + ) + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + pipe = UnCLIPPipeline( + prior=prior, + decoder=decoder, + text_proj=text_proj, + text_encoder=text_encoder, + tokenizer=tokenizer, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe = pipe.to(device) + + generator = torch.Generator(device=device).manual_seed(0) + dtype = prior.dtype + batch_size = 1 + + shape = (batch_size, prior.config.embedding_dim) + prior_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size) + decoder_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + shape = ( + batch_size, + super_res_first.in_channels // 2, + super_res_first.sample_size, + super_res_first.sample_size, + ) + super_res_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + pipe.set_progress_bar_config(disable=None) + + prompt = "this is a prompt example" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe( + [prompt], + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + prior_latents=prior_latents, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + output_type="np", + ) + image = output.images + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ) + text_model_output = text_encoder(text_inputs.input_ids) + text_attention_mask = text_inputs.attention_mask + + generator = torch.Generator(device=device).manual_seed(0) + image_from_text = pipe( + generator=generator, + prior_num_inference_steps=2, + decoder_num_inference_steps=2, + super_res_num_inference_steps=2, + prior_latents=prior_latents, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + text_model_output=text_model_output, + text_attention_mask=text_attention_mask, + output_type="np", + )[0] + + # make sure passing text embeddings manually is identical + assert np.abs(image - image_from_text).max() < 1e-4 + @slow @require_torch_gpu @@ -281,7 +395,7 @@ def test_unclip_karlo(self): assert image.shape == (256, 256, 3) assert np.abs(expected_image - image).max() < 1e-2 - def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + def test_unclip_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py new file mode 100644 index 000000000000..87ad14146a11 --- /dev/null +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -0,0 +1,494 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + DiffusionPipeline, + UnCLIPImageVariationPipeline, + UnCLIPScheduler, + UNet2DConditionModel, + UNet2DModel, +) +from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel +from diffusers.utils import floats_tensor, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import load_image, require_torch_gpu +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class UnCLIPImageVariationPipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config) + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + num_hidden_layers=5, + num_attention_heads=4, + image_size=32, + intermediate_size=37, + patch_size=1, + ) + return CLIPVisionModelWithProjection(config) + + @property + def dummy_text_proj(self): + torch.manual_seed(0) + + model_kwargs = { + "clip_embeddings_dim": self.text_embedder_hidden_size, + "time_embed_dim": self.time_embed_dim, + "cross_attention_dim": self.cross_attention_dim, + } + + model = UnCLIPTextProjModel(**model_kwargs) + return model + + @property + def dummy_decoder(self): + torch.manual_seed(0) + + model_kwargs = { + "sample_size": 64, + # RGB in channels + "in_channels": 3, + # Out channels is double in channels because predicts mean and variance + "out_channels": 6, + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": "identity", + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_super_res_kwargs(self): + return { + "sample_size": 128, + "layers_per_block": 1, + "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"), + "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"), + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "in_channels": 6, + "out_channels": 3, + } + + @property + def dummy_super_res_first(self): + torch.manual_seed(0) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + @property + def dummy_super_res_last(self): + # seeded differently to get different unet than `self.dummy_super_res_first` + torch.manual_seed(1) + + model = UNet2DModel(**self.dummy_super_res_kwargs) + return model + + def get_pipeline(self, device): + decoder = self.dummy_decoder + text_proj = self.dummy_text_proj + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + super_res_first = self.dummy_super_res_first + super_res_last = self.dummy_super_res_last + + decoder_scheduler = UnCLIPScheduler( + variance_type="learned_range", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + super_res_scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="epsilon", + num_train_timesteps=1000, + ) + + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) + + image_encoder = self.dummy_image_encoder + + pipe = UnCLIPImageVariationPipeline( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + return pipe + + def get_pipeline_inputs(self, device, seed, pil_image=False): + input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + generator = torch.Generator(device=device).manual_seed(seed) + + if pil_image: + input_image = input_image * 0.5 + 0.5 + input_image = input_image.clamp(0, 1) + input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy() + input_image = DiffusionPipeline.numpy_to_pil(input_image)[0] + + return { + "image": input_image, + "generator": generator, + "decoder_num_inference_steps": 2, + "super_res_num_inference_steps": 2, + "output_type": "np", + } + + def test_unclip_image_variation_input_tensor(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed) + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9988, + 0.9997, + 0.9944, + 0.0003, + 0.0003, + 0.9974, + 0.0003, + 0.0004, + 0.9931, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_image(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9988, + 0.9997, + 0.9944, + 0.0003, + 0.0003, + 0.9974, + 0.0003, + 0.0004, + 0.9931, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_list_images(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + pipeline_inputs["image"] = [ + pipeline_inputs["image"], + pipeline_inputs["image"], + ] + + output = pipe(**pipeline_inputs) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + tuple_pipeline_inputs["image"] = [ + tuple_pipeline_inputs["image"], + tuple_pipeline_inputs["image"], + ] + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (2, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9997, + 0.9997, + 0.0003, + 0.0003, + 0.9950, + 0.0003, + 0.9993, + 0.9957, + 0.0004, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_image_variation_input_num_images_per_prompt(self): + device = "cpu" + seed = 0 + + pipe = self.get_pipeline(device) + + pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + pipeline_inputs["image"] = [ + pipeline_inputs["image"], + pipeline_inputs["image"], + ] + + output = pipe(**pipeline_inputs, num_images_per_prompt=2) + image = output.images + + tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True) + tuple_pipeline_inputs["image"] = [ + tuple_pipeline_inputs["image"], + tuple_pipeline_inputs["image"], + ] + + image_from_tuple = pipe( + **tuple_pipeline_inputs, + num_images_per_prompt=2, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (4, 128, 128, 3) + + expected_slice = np.array( + [ + 0.9997, + 0.9997, + 0.0008, + 0.9952, + 0.9980, + 0.9997, + 0.9961, + 0.9997, + 0.9995, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_unclip_passed_image_embed(self): + device = torch.device("cpu") + seed = 0 + + class DummyScheduler: + init_noise_sigma = 1 + + pipe = self.get_pipeline(device) + + generator = torch.Generator(device=device).manual_seed(0) + dtype = pipe.decoder.dtype + batch_size = 1 + + shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size) + decoder_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + shape = ( + batch_size, + pipe.super_res_first.in_channels // 2, + pipe.super_res_first.sample_size, + pipe.super_res_first.sample_size, + ) + super_res_latents = pipe.prepare_latents( + shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() + ) + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + + img_out_1 = pipe( + **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents + ).images + + pipeline_inputs = self.get_pipeline_inputs(device, seed) + # Don't pass image, instead pass embedding + image = pipeline_inputs.pop("image") + image_embeddings = pipe.image_encoder(image).image_embeds + + img_out_2 = pipe( + **pipeline_inputs, + decoder_latents=decoder_latents, + super_res_latents=super_res_latents, + image_embeddings=image_embeddings, + ).images + + # make sure passing text embeddings manually is identical + assert np.abs(img_out_1 - img_out_2).max() < 1e-4 + + +@slow +@require_torch_gpu +class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_unclip_image_variation_karlo(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/unclip/karlo_v1_alpha_cat_variation_fp16.npy" + ) + + pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers") + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipeline( + input_image, + num_images_per_prompt=1, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + assert np.abs(expected_image - image).max() < 5e-2 diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 68ab914b4209..42f683f887fe 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -21,7 +21,7 @@ import numpy as np import torch -from diffusers.modeling_utils import ModelMixin +from diffusers.models import ModelMixin from diffusers.training_utils import EMAModel from diffusers.utils import torch_device @@ -70,9 +70,9 @@ def test_from_save_pretrained_dtype(self): with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) assert new_model.dtype == dtype - new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) assert new_model.dtype == dtype def test_determinism(self): diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index ff02ee8ea4b4..3ed90748c547 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -33,6 +33,7 @@ DDIMScheduler, DDPMPipeline, DDPMScheduler, + DiffusionPipeline, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, @@ -45,7 +46,6 @@ UNet2DModel, logging, ) -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, nightly, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu @@ -704,7 +704,7 @@ def test_smart_download(self): def test_warning_unused_kwargs(self): model_id = "hf-internal-testing/unet-pipeline-dummy" - logger = logging.get_logger("diffusers.pipeline_utils") + logger = logging.get_logger("diffusers.pipelines") with tempfile.TemporaryDirectory() as tmpdirname: with CaptureLogger(logger) as cap_logger: DiffusionPipeline.from_pretrained(